From de18c4f76a41ff2fa5758a36f6d21d2358c6851b Mon Sep 17 00:00:00 2001 From: Mikoto <60188643+avdb13@users.noreply.github.com> Date: Thu, 21 Nov 2024 06:29:42 +0000 Subject: [PATCH] feat: Extract common abstractions (#244) * initialize crate * add `Cached` trait * add `Throttled` trait * add resolvers * add store * add workflows * fix `atrium-oauth` * add error conversions * change type visibility * fix identity crate * fix oauth-client crate * small fix * mofify crate authors * change `Resolver` type signature * apply suggestions * fix `Throttled` tests * fix wasm `CacheTrait` --- .github/workflows/common.yml | 19 ++ .github/workflows/wasm.yml | 1 + Cargo.toml | 2 + atrium-common/Cargo.toml | 36 +++ atrium-common/src/lib.rs | 3 + atrium-common/src/resolver.rs | 222 ++++++++++++++++ atrium-common/src/resolver/cached.rs | 31 +++ atrium-common/src/resolver/throttled.rs | 43 +++ atrium-common/src/store.rs | 19 ++ atrium-common/src/store/memory.rs | 46 ++++ atrium-common/src/types.rs | 2 + atrium-common/src/types/cached.rs | 34 +++ atrium-common/src/types/cached/impl.rs | 24 ++ .../src/types/cached/impl}/moka.rs | 4 +- .../src/types/cached/impl}/wasm.rs | 4 +- atrium-common/src/types/throttled.rs | 31 +++ atrium-oauth/identity/Cargo.toml | 21 +- atrium-oauth/identity/src/did.rs | 6 +- .../identity/src/did/common_resolver.rs | 3 +- atrium-oauth/identity/src/did/plc_resolver.rs | 3 +- atrium-oauth/identity/src/did/web_resolver.rs | 3 +- atrium-oauth/identity/src/handle.rs | 6 +- .../identity/src/handle/appview_resolver.rs | 3 +- .../identity/src/handle/atproto_resolver.rs | 4 +- .../identity/src/handle/dns_resolver.rs | 3 +- .../src/handle/well_known_resolver.rs | 3 +- .../identity/src/identity_resolver.rs | 4 +- atrium-oauth/identity/src/lib.rs | 1 - atrium-oauth/identity/src/resolver.rs | 246 ------------------ .../identity/src/resolver/cache_impl.rs | 9 - .../identity/src/resolver/cached_resolver.rs | 61 ----- .../src/resolver/throttled_resolver.rs | 59 ----- atrium-oauth/oauth-client/Cargo.toml | 1 + atrium-oauth/oauth-client/src/oauth_client.rs | 3 +- atrium-oauth/oauth-client/src/resolver.rs | 33 ++- .../oauth_authorization_server_resolver.rs | 4 +- .../oauth_protected_resource_resolver.rs | 4 +- 37 files changed, 574 insertions(+), 427 deletions(-) create mode 100644 .github/workflows/common.yml create mode 100644 atrium-common/Cargo.toml create mode 100644 atrium-common/src/lib.rs create mode 100644 atrium-common/src/resolver.rs create mode 100644 atrium-common/src/resolver/cached.rs create mode 100644 atrium-common/src/resolver/throttled.rs create mode 100644 atrium-common/src/store.rs create mode 100644 atrium-common/src/store/memory.rs create mode 100644 atrium-common/src/types.rs create mode 100644 atrium-common/src/types/cached.rs create mode 100644 atrium-common/src/types/cached/impl.rs rename {atrium-oauth/identity/src/resolver/cache_impl => atrium-common/src/types/cached/impl}/moka.rs (88%) rename {atrium-oauth/identity/src/resolver/cache_impl => atrium-common/src/types/cached/impl}/wasm.rs (95%) create mode 100644 atrium-common/src/types/throttled.rs delete mode 100644 atrium-oauth/identity/src/resolver/cache_impl.rs delete mode 100644 atrium-oauth/identity/src/resolver/cached_resolver.rs delete mode 100644 atrium-oauth/identity/src/resolver/throttled_resolver.rs diff --git a/.github/workflows/common.yml b/.github/workflows/common.yml new file mode 100644 index 00000000..284712ab --- /dev/null +++ b/.github/workflows/common.yml @@ -0,0 +1,19 @@ +name: Common +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] +env: + CARGO_TERM_COLOR: always +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Build + run: | + cargo build -p atrium-common --verbose + - name: Run tests + run: | + cargo test -p atrium-common --lib diff --git a/.github/workflows/wasm.yml b/.github/workflows/wasm.yml index 472e97de..4f84f176 100644 --- a/.github/workflows/wasm.yml +++ b/.github/workflows/wasm.yml @@ -67,3 +67,4 @@ jobs: - run: wasm-pack test --node atrium-xrpc - run: wasm-pack test --node atrium-xrpc-client - run: wasm-pack test --node atrium-oauth/identity + - run: wasm-pack test --node atrium-common diff --git a/Cargo.toml b/Cargo.toml index 726d8dcd..7623fe3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "atrium-api", + "atrium-common", "atrium-crypto", "atrium-xrpc", "atrium-xrpc-client", @@ -26,6 +27,7 @@ keywords = ["atproto", "bluesky"] [workspace.dependencies] # Intra-workspace dependencies atrium-api = { version = "0.24.8", path = "atrium-api", default-features = false } +atrium-common = { version = "0.1.0", path = "atrium-common" } atrium-identity = { version = "0.1.0", path = "atrium-oauth/identity" } atrium-xrpc = { version = "0.12.0", path = "atrium-xrpc" } atrium-xrpc-client = { version = "0.5.10", path = "atrium-xrpc-client" } diff --git a/atrium-common/Cargo.toml b/atrium-common/Cargo.toml new file mode 100644 index 00000000..9bda3a56 --- /dev/null +++ b/atrium-common/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "atrium-common" +version = "0.1.0" +authors = ["sugyan ", "avdb13 "] +edition.workspace = true +rust-version.workspace = true +description = "Utility library for common abstractions in atproto" +documentation = "https://docs.rs/atrium-common" +readme = "README.md" +repository.workspace = true +license.workspace = true +keywords = ["atproto", "bluesky"] + +[dependencies] +dashmap.workspace = true +thiserror.workspace = true +tokio = { workspace = true, default-features = false, features = ["sync"] } +trait-variant.workspace = true + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +moka = { workspace = true, features = ["future"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +lru.workspace = true +web-time.workspace = true + +[dev-dependencies] +futures.workspace = true + +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dev-dependencies] +gloo-timers.workspace = true +tokio = { workspace = true, features = ["time"] } +wasm-bindgen-test.workspace = true diff --git a/atrium-common/src/lib.rs b/atrium-common/src/lib.rs new file mode 100644 index 00000000..8a69602e --- /dev/null +++ b/atrium-common/src/lib.rs @@ -0,0 +1,3 @@ +pub mod resolver; +pub mod store; +pub mod types; diff --git a/atrium-common/src/resolver.rs b/atrium-common/src/resolver.rs new file mode 100644 index 00000000..72e3e26a --- /dev/null +++ b/atrium-common/src/resolver.rs @@ -0,0 +1,222 @@ +mod cached; +mod throttled; + +pub use self::cached::CachedResolver; +pub use self::throttled::ThrottledResolver; +use std::future::Future; + +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] +pub trait Resolver { + type Input: ?Sized; + type Output; + type Error; + + fn resolve( + &self, + input: &Self::Input, + ) -> impl Future>; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::cached::r#impl::{Cache, CacheImpl}; + use crate::types::cached::{CacheConfig, Cacheable}; + use crate::types::throttled::Throttleable; + use std::collections::HashMap; + use std::sync::Arc; + use std::time::Duration; + use tokio::sync::RwLock; + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::wasm_bindgen_test; + + #[cfg(not(target_arch = "wasm32"))] + async fn sleep(duration: Duration) { + tokio::time::sleep(duration).await; + } + + #[cfg(target_arch = "wasm32")] + async fn sleep(duration: Duration) { + gloo_timers::future::sleep(duration).await; + } + + #[derive(Debug, PartialEq)] + struct Error; + + type Result = core::result::Result; + + struct MockResolver { + data: HashMap, + counts: Arc>>, + } + + impl Resolver for MockResolver { + type Input = String; + type Output = String; + type Error = Error; + + async fn resolve(&self, input: &Self::Input) -> Result { + sleep(Duration::from_millis(10)).await; + *self.counts.write().await.entry(input.clone()).or_default() += 1; + if let Some(value) = self.data.get(input) { + Ok(value.clone()) + } else { + Err(Error) + } + } + } + + fn mock_resolver(counts: Arc>>) -> MockResolver { + MockResolver { + data: [ + (String::from("k1"), String::from("v1")), + (String::from("k2"), String::from("v2")), + ] + .into_iter() + .collect(), + counts, + } + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_no_cached() { + let counts = Arc::new(RwLock::new(HashMap::new())); + let resolver = mock_resolver(counts.clone()); + for (input, expected) in [ + ("k1", Some("v1")), + ("k2", Some("v2")), + ("k2", Some("v2")), + ("k1", Some("v1")), + ("k3", None), + ("k1", Some("v1")), + ("k3", None), + ] { + let result = resolver.resolve(&input.to_string()).await; + match expected { + Some(value) => assert_eq!(result.expect("failed to resolve"), value), + None => assert_eq!(result.expect_err("succesfully resolved"), Error), + } + } + assert_eq!( + *counts.read().await, + [(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),] + .into_iter() + .collect() + ); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_cached() { + let counts = Arc::new(RwLock::new(HashMap::new())); + let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default())); + for (input, expected) in [ + ("k1", Some("v1")), + ("k2", Some("v2")), + ("k2", Some("v2")), + ("k1", Some("v1")), + ("k3", None), + ("k1", Some("v1")), + ("k3", None), + ] { + let result = resolver.resolve(&input.to_string()).await; + match expected { + Some(value) => assert_eq!(result.expect("failed to resolve"), value), + None => assert_eq!(result.expect_err("succesfully resolved"), Error), + } + } + assert_eq!( + *counts.read().await, + [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),] + .into_iter() + .collect() + ); + } + + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_cached_with_max_capacity() { + let counts = Arc::new(RwLock::new(HashMap::new())); + let resolver = mock_resolver(counts.clone()) + .cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() })); + for (input, expected) in [ + ("k1", Some("v1")), + ("k2", Some("v2")), + ("k2", Some("v2")), + ("k1", Some("v1")), + ("k3", None), + ("k1", Some("v1")), + ("k3", None), + ] { + let result = resolver.resolve(&input.to_string()).await; + match expected { + Some(value) => assert_eq!(result.expect("failed to resolve"), value), + None => assert_eq!(result.expect_err("succesfully resolved"), Error), + } + } + assert_eq!( + *counts.read().await, + [(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),] + .into_iter() + .collect() + ); + } + + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_cached_with_time_to_live() { + let counts = Arc::new(RwLock::new(HashMap::new())); + let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig { + time_to_live: Some(Duration::from_millis(10)), + ..Default::default() + })); + for _ in 0..10 { + let result = resolver.resolve(&String::from("k1")).await; + assert_eq!(result.expect("failed to resolve"), "v1"); + } + sleep(Duration::from_millis(10)).await; + for _ in 0..10 { + let result = resolver.resolve(&String::from("k1")).await; + assert_eq!(result.expect("failed to resolve"), "v1"); + } + assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect()); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_throttled() { + let counts = Arc::new(RwLock::new(HashMap::new())); + let resolver = Arc::new(mock_resolver(counts.clone()).throttled()); + + let mut handles = Vec::new(); + for (input, expected) in [ + ("k1", Some("v1")), + ("k2", Some("v2")), + ("k2", Some("v2")), + ("k1", Some("v1")), + ("k3", None), + ("k1", Some("v1")), + ("k3", None), + ] { + let resolver = resolver.clone(); + handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) }); + } + for (result, expected) in futures::future::join_all(handles).await { + let result = result.and_then(|opt| opt.ok_or(Error)); + + match expected { + Some(value) => { + assert_eq!(result.expect("failed to resolve"), value) + } + None => assert_eq!(result.expect_err("succesfully resolved"), Error), + } + } + assert_eq!( + *counts.read().await, + [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),] + .into_iter() + .collect() + ); + } +} diff --git a/atrium-common/src/resolver/cached.rs b/atrium-common/src/resolver/cached.rs new file mode 100644 index 00000000..6f55f56d --- /dev/null +++ b/atrium-common/src/resolver/cached.rs @@ -0,0 +1,31 @@ +use std::hash::Hash; + +use crate::types::cached::r#impl::{Cache, CacheImpl}; +use crate::types::cached::Cached; + +use super::Resolver; + +pub type CachedResolver = Cached::Input, ::Output>>; + +impl Resolver for Cached +where + R: Resolver + Send + Sync + 'static, + R::Input: Clone + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, + C: Cache + Send + Sync + 'static, + C::Input: Clone + Hash + Eq + Send + Sync + 'static, + C::Output: Clone + Send + Sync + 'static, +{ + type Input = R::Input; + type Output = R::Output; + type Error = R::Error; + + async fn resolve(&self, input: &Self::Input) -> Result { + if let Some(output) = self.cache.get(input).await { + return Ok(output); + } + let output = self.inner.resolve(input).await?; + self.cache.set(input.clone(), output.clone()).await; + Ok(output) + } +} diff --git a/atrium-common/src/resolver/throttled.rs b/atrium-common/src/resolver/throttled.rs new file mode 100644 index 00000000..f16893d4 --- /dev/null +++ b/atrium-common/src/resolver/throttled.rs @@ -0,0 +1,43 @@ +use std::{hash::Hash, sync::Arc}; + +use dashmap::{DashMap, Entry}; +use tokio::sync::broadcast::{channel, Sender}; +use tokio::sync::Mutex; + +use crate::types::throttled::Throttled; + +use super::Resolver; + +pub type SenderMap = + DashMap<::Input, Arc::Output>>>>>; + +pub type ThrottledResolver = Throttled>; + +impl Resolver for Throttled> +where + R: Resolver + Send + Sync + 'static, + R::Input: Clone + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, +{ + type Input = R::Input; + type Output = Option; + type Error = R::Error; + + async fn resolve(&self, input: &Self::Input) -> Result { + match self.pending.entry(input.clone()) { + Entry::Occupied(occupied) => { + let tx = occupied.get().lock().await.clone(); + drop(occupied); + Ok(tx.subscribe().recv().await.expect("recv")) + } + Entry::Vacant(vacant) => { + let (tx, _) = channel(1); + vacant.insert(Arc::new(Mutex::new(tx.clone()))); + let result = self.inner.resolve(input).await; + tx.send(result.as_ref().ok().cloned()).ok(); + self.pending.remove(input); + result.map(Some) + } + } + } +} diff --git a/atrium-common/src/store.rs b/atrium-common/src/store.rs new file mode 100644 index 00000000..d2d8a30a --- /dev/null +++ b/atrium-common/src/store.rs @@ -0,0 +1,19 @@ +pub mod memory; + +use std::error::Error; +use std::future::Future; +use std::hash::Hash; + +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] +pub trait Store +where + K: Eq + Hash, + V: Clone, +{ + type Error: Error; + + fn get(&self, key: &K) -> impl Future, Self::Error>>; + fn set(&self, key: K, value: V) -> impl Future>; + fn del(&self, key: &K) -> impl Future>; + fn clear(&self) -> impl Future>; +} diff --git a/atrium-common/src/store/memory.rs b/atrium-common/src/store/memory.rs new file mode 100644 index 00000000..dc81fd7c --- /dev/null +++ b/atrium-common/src/store/memory.rs @@ -0,0 +1,46 @@ +use super::Store; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; +use std::sync::{Arc, Mutex}; +use thiserror::Error; + +#[derive(Error, Debug)] +#[error("memory store error")] +pub struct Error; + +// TODO: LRU cache? +#[derive(Clone)] +pub struct MemoryStore { + store: Arc>>, +} + +impl Default for MemoryStore { + fn default() -> Self { + Self { store: Arc::new(Mutex::new(HashMap::new())) } + } +} + +impl Store for MemoryStore +where + K: Debug + Eq + Hash + Send + Sync + 'static, + V: Debug + Clone + Send + Sync + 'static, +{ + type Error = Error; + + async fn get(&self, key: &K) -> Result, Self::Error> { + Ok(self.store.lock().unwrap().get(key).cloned()) + } + async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { + self.store.lock().unwrap().insert(key, value); + Ok(()) + } + async fn del(&self, key: &K) -> Result<(), Self::Error> { + self.store.lock().unwrap().remove(key); + Ok(()) + } + async fn clear(&self) -> Result<(), Self::Error> { + self.store.lock().unwrap().clear(); + Ok(()) + } +} diff --git a/atrium-common/src/types.rs b/atrium-common/src/types.rs new file mode 100644 index 00000000..5c9c52ef --- /dev/null +++ b/atrium-common/src/types.rs @@ -0,0 +1,2 @@ +pub mod cached; +pub mod throttled; diff --git a/atrium-common/src/types/cached.rs b/atrium-common/src/types/cached.rs new file mode 100644 index 00000000..e9e19a77 --- /dev/null +++ b/atrium-common/src/types/cached.rs @@ -0,0 +1,34 @@ +pub mod r#impl; + +use std::fmt::Debug; +use std::time::Duration; + +#[derive(Clone, Debug, Default)] +pub struct CacheConfig { + pub max_capacity: Option, + pub time_to_live: Option, +} + +pub trait Cacheable +where + Self: Sized, +{ + fn cached(self, cache: C) -> Cached; +} + +impl Cacheable for T { + fn cached(self, cache: C) -> Cached { + Cached::new(self, cache) + } +} + +pub struct Cached { + pub inner: T, + pub cache: C, +} + +impl Cached { + pub fn new(inner: T, cache: C) -> Self { + Self { inner, cache } + } +} diff --git a/atrium-common/src/types/cached/impl.rs b/atrium-common/src/types/cached/impl.rs new file mode 100644 index 00000000..d7f634f2 --- /dev/null +++ b/atrium-common/src/types/cached/impl.rs @@ -0,0 +1,24 @@ +#[cfg(not(target_arch = "wasm32"))] +mod moka; +#[cfg(target_arch = "wasm32")] +mod wasm; + +use std::future::Future; +use std::hash::Hash; + +#[cfg(not(target_arch = "wasm32"))] +pub use self::moka::MokaCache as CacheImpl; +#[cfg(target_arch = "wasm32")] +pub use self::wasm::WasmCache as CacheImpl; + +use super::CacheConfig; + +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] +pub trait Cache { + type Input: Hash + Eq + Sync + 'static; + type Output: Clone + Sync + 'static; + + fn new(config: CacheConfig) -> Self; + fn get(&self, key: &Self::Input) -> impl Future>; + fn set(&self, key: Self::Input, value: Self::Output) -> impl Future; +} diff --git a/atrium-oauth/identity/src/resolver/cache_impl/moka.rs b/atrium-common/src/types/cached/impl/moka.rs similarity index 88% rename from atrium-oauth/identity/src/resolver/cache_impl/moka.rs rename to atrium-common/src/types/cached/impl/moka.rs index f35fa3a8..cbc2a91d 100644 --- a/atrium-oauth/identity/src/resolver/cache_impl/moka.rs +++ b/atrium-common/src/types/cached/impl/moka.rs @@ -1,4 +1,4 @@ -use super::super::cached_resolver::{Cache as CacheTrait, CachedResolverConfig}; +use super::{Cache as CacheTrait, CacheConfig}; use moka::{future::Cache, policy::EvictionPolicy}; use std::collections::hash_map::RandomState; use std::hash::Hash; @@ -15,7 +15,7 @@ where type Input = I; type Output = O; - fn new(config: CachedResolverConfig) -> Self { + fn new(config: CacheConfig) -> Self { let mut builder = Cache::::builder().eviction_policy(EvictionPolicy::lru()); if let Some(max_capacity) = config.max_capacity { builder = builder.max_capacity(max_capacity); diff --git a/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs b/atrium-common/src/types/cached/impl/wasm.rs similarity index 95% rename from atrium-oauth/identity/src/resolver/cache_impl/wasm.rs rename to atrium-common/src/types/cached/impl/wasm.rs index 8af03932..ba82c48a 100644 --- a/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs +++ b/atrium-common/src/types/cached/impl/wasm.rs @@ -1,4 +1,4 @@ -use super::super::cached_resolver::{Cache as CacheTrait, CachedResolverConfig}; +use super::{Cache as CacheTrait, CacheConfig}; use lru::LruCache; use std::collections::HashMap; use std::hash::Hash; @@ -64,7 +64,7 @@ where type Input = I; type Output = O; - fn new(config: CachedResolverConfig) -> Self { + fn new(config: CacheConfig) -> Self { let store = if let Some(max_capacity) = config.max_capacity { Store::Lru(LruCache::new( NonZeroUsize::new(max_capacity as usize) diff --git a/atrium-common/src/types/throttled.rs b/atrium-common/src/types/throttled.rs new file mode 100644 index 00000000..69fa588e --- /dev/null +++ b/atrium-common/src/types/throttled.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +pub trait Throttleable

+where + Self: std::marker::Sized, +{ + fn throttled(self) -> Throttled; +} + +impl Throttleable

for T +where + P: Default, +{ + fn throttled(self) -> Throttled { + Throttled::new(self) + } +} + +pub struct Throttled { + pub inner: T, + pub pending: Arc

, +} + +impl Throttled +where + P: Default, +{ + pub fn new(inner: T) -> Self { + Self { inner, pending: Arc::new(P::default()) } + } +} diff --git a/atrium-oauth/identity/Cargo.toml b/atrium-oauth/identity/Cargo.toml index 55a0b15b..6273ff77 100644 --- a/atrium-oauth/identity/Cargo.toml +++ b/atrium-oauth/identity/Cargo.toml @@ -15,34 +15,15 @@ keywords = ["atproto", "bluesky", "identity"] [dependencies] atrium-api = { workspace = true, default-features = false } +atrium-common.workspace = true atrium-xrpc.workspace = true -dashmap.workspace = true hickory-proto = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_html_form.workspace = true serde_json.workspace = true thiserror.workspace = true -tokio = { workspace = true, default-features = false, features = ["sync"] } trait-variant.workspace = true -[target.'cfg(not(target_arch = "wasm32"))'.dependencies] -moka = { workspace = true, features = ["future"] } - -[target.'cfg(target_arch = "wasm32")'.dependencies] -lru.workspace = true -web-time.workspace = true - -[dev-dependencies] -futures.workspace = true - -[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } - -[target.'cfg(target_arch = "wasm32")'.dev-dependencies] -gloo-timers.workspace = true -tokio = { workspace = true, features = ["time"] } -wasm-bindgen-test.workspace = true - [features] default = [] doh-handle-resolver = ["dep:hickory-proto"] diff --git a/atrium-oauth/identity/src/did.rs b/atrium-oauth/identity/src/did.rs index 79621721..0b731cb1 100644 --- a/atrium-oauth/identity/src/did.rs +++ b/atrium-oauth/identity/src/did.rs @@ -2,10 +2,12 @@ mod common_resolver; mod plc_resolver; mod web_resolver; +use crate::Error; + pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig}; pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL; -use crate::Resolver; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; +use atrium_common::resolver::Resolver; -pub trait DidResolver: Resolver {} +pub trait DidResolver: Resolver {} diff --git a/atrium-oauth/identity/src/did/common_resolver.rs b/atrium-oauth/identity/src/did/common_resolver.rs index 601127f7..5c18f634 100644 --- a/atrium-oauth/identity/src/did/common_resolver.rs +++ b/atrium-oauth/identity/src/did/common_resolver.rs @@ -1,12 +1,12 @@ use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; +use atrium_common::resolver::Resolver; use atrium_xrpc::HttpClient; use super::plc_resolver::{PlcDidResolver, PlcDidResolverConfig}; use super::web_resolver::{WebDidResolver, WebDidResolverConfig}; use super::DidResolver; use crate::error::{Error, Result}; -use crate::Resolver; use std::sync::Arc; #[derive(Clone, Debug)] @@ -41,6 +41,7 @@ where { type Input = Did; type Output = DidDocument; + type Error = Error; async fn resolve(&self, did: &Self::Input) -> Result { match did.strip_prefix("did:").and_then(|s| s.split_once(':').map(|(method, _)| method)) { diff --git a/atrium-oauth/identity/src/did/plc_resolver.rs b/atrium-oauth/identity/src/did/plc_resolver.rs index 5f8dc1e7..5d32582e 100644 --- a/atrium-oauth/identity/src/did/plc_resolver.rs +++ b/atrium-oauth/identity/src/did/plc_resolver.rs @@ -1,8 +1,8 @@ use super::DidResolver; use crate::error::{Error, Result}; -use crate::Resolver; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; +use atrium_common::resolver::Resolver; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, Uri}; use atrium_xrpc::HttpClient; @@ -33,6 +33,7 @@ where { type Input = Did; type Output = DidDocument; + type Error = Error; async fn resolve(&self, did: &Self::Input) -> Result { let uri = Builder::from(self.plc_directory_url.parse::()?) diff --git a/atrium-oauth/identity/src/did/web_resolver.rs b/atrium-oauth/identity/src/did/web_resolver.rs index 582bdd00..eba6ed99 100644 --- a/atrium-oauth/identity/src/did/web_resolver.rs +++ b/atrium-oauth/identity/src/did/web_resolver.rs @@ -1,8 +1,8 @@ use super::DidResolver; use crate::error::{Error, Result}; -use crate::Resolver; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; +use atrium_common::resolver::Resolver; use atrium_xrpc::http::{header::ACCEPT, Request, Uri}; use atrium_xrpc::HttpClient; use std::sync::Arc; @@ -30,6 +30,7 @@ where { type Input = Did; type Output = DidDocument; + type Error = Error; async fn resolve(&self, did: &Self::Input) -> Result { let document_url = format!( diff --git a/atrium-oauth/identity/src/handle.rs b/atrium-oauth/identity/src/handle.rs index 2ae285dd..da7f0d93 100644 --- a/atrium-oauth/identity/src/handle.rs +++ b/atrium-oauth/identity/src/handle.rs @@ -5,13 +5,15 @@ mod dns_resolver; mod doh_dns_txt_resolver; mod well_known_resolver; +use crate::Error; + pub use self::appview_resolver::{AppViewHandleResolver, AppViewHandleResolverConfig}; pub use self::atproto_resolver::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; pub use self::dns_resolver::DnsTxtResolver; #[cfg(feature = "doh-handle-resolver")] pub use self::doh_dns_txt_resolver::{DohDnsTxtResolver, DohDnsTxtResolverConfig}; pub use self::well_known_resolver::{WellKnownHandleResolver, WellKnownHandleResolverConfig}; -use crate::Resolver; use atrium_api::types::string::{Did, Handle}; +use atrium_common::resolver::Resolver; -pub trait HandleResolver: Resolver {} +pub trait HandleResolver: Resolver {} diff --git a/atrium-oauth/identity/src/handle/appview_resolver.rs b/atrium-oauth/identity/src/handle/appview_resolver.rs index 90255a35..098ab783 100644 --- a/atrium-oauth/identity/src/handle/appview_resolver.rs +++ b/atrium-oauth/identity/src/handle/appview_resolver.rs @@ -1,8 +1,8 @@ use super::HandleResolver; use crate::error::{Error, Result}; -use crate::Resolver; use atrium_api::com::atproto::identity::resolve_handle; use atrium_api::types::string::{Did, Handle}; +use atrium_common::resolver::Resolver; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, Uri}; use atrium_xrpc::HttpClient; @@ -31,6 +31,7 @@ where { type Input = Handle; type Output = Did; + type Error = Error; async fn resolve(&self, handle: &Self::Input) -> Result { let uri = Builder::from(self.service_url.parse::()?) diff --git a/atrium-oauth/identity/src/handle/atproto_resolver.rs b/atrium-oauth/identity/src/handle/atproto_resolver.rs index 25ec54b4..98579f81 100644 --- a/atrium-oauth/identity/src/handle/atproto_resolver.rs +++ b/atrium-oauth/identity/src/handle/atproto_resolver.rs @@ -2,8 +2,9 @@ use super::dns_resolver::{DnsHandleResolver, DnsHandleResolverConfig, DnsTxtReso use super::well_known_resolver::{WellKnownHandleResolver, WellKnownHandleResolverConfig}; use super::HandleResolver; use crate::error::Result; -use crate::Resolver; +use crate::Error; use atrium_api::types::string::{Did, Handle}; +use atrium_common::resolver::Resolver; use atrium_xrpc::HttpClient; use std::sync::Arc; @@ -38,6 +39,7 @@ where { type Input = Handle; type Output = Did; + type Error = Error; async fn resolve(&self, handle: &Self::Input) -> Result { let d_fut = self.dns.resolve(handle); diff --git a/atrium-oauth/identity/src/handle/dns_resolver.rs b/atrium-oauth/identity/src/handle/dns_resolver.rs index 7cdc6a92..984254b5 100644 --- a/atrium-oauth/identity/src/handle/dns_resolver.rs +++ b/atrium-oauth/identity/src/handle/dns_resolver.rs @@ -1,7 +1,7 @@ use super::HandleResolver; use crate::error::{Error, Result}; -use crate::Resolver; use atrium_api::types::string::{Did, Handle}; +use atrium_common::resolver::Resolver; use std::future::Future; const SUBDOMAIN: &str = "_atproto"; @@ -41,6 +41,7 @@ where { type Input = Handle; type Output = Did; + type Error = Error; async fn resolve(&self, handle: &Self::Input) -> Result { for result in self diff --git a/atrium-oauth/identity/src/handle/well_known_resolver.rs b/atrium-oauth/identity/src/handle/well_known_resolver.rs index 9f04b2b7..e3542b31 100644 --- a/atrium-oauth/identity/src/handle/well_known_resolver.rs +++ b/atrium-oauth/identity/src/handle/well_known_resolver.rs @@ -1,7 +1,7 @@ use super::HandleResolver; use crate::error::{Error, Result}; -use crate::Resolver; use atrium_api::types::string::{Did, Handle}; +use atrium_common::resolver::Resolver; use atrium_xrpc::http::Request; use atrium_xrpc::HttpClient; use std::sync::Arc; @@ -29,6 +29,7 @@ where { type Input = Handle; type Output = Did; + type Error = Error; async fn resolve(&self, handle: &Self::Input) -> Result { let url = format!("https://{}{WELL_KNWON_PATH}", handle.as_str()); diff --git a/atrium-oauth/identity/src/identity_resolver.rs b/atrium-oauth/identity/src/identity_resolver.rs index e8244bce..a70e1856 100644 --- a/atrium-oauth/identity/src/identity_resolver.rs +++ b/atrium-oauth/identity/src/identity_resolver.rs @@ -1,6 +1,7 @@ use crate::error::{Error, Result}; -use crate::{did::DidResolver, handle::HandleResolver, Resolver}; +use crate::{did::DidResolver, handle::HandleResolver}; use atrium_api::types::string::AtIdentifier; +use atrium_common::resolver::Resolver; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -33,6 +34,7 @@ where { type Input = str; type Output = ResolvedIdentity; + type Error = Error; async fn resolve(&self, input: &Self::Input) -> Result { let document = diff --git a/atrium-oauth/identity/src/lib.rs b/atrium-oauth/identity/src/lib.rs index d64a61cf..9f397322 100644 --- a/atrium-oauth/identity/src/lib.rs +++ b/atrium-oauth/identity/src/lib.rs @@ -5,4 +5,3 @@ pub mod identity_resolver; pub mod resolver; pub use self::error::{Error, Result}; -pub use self::resolver::Resolver; diff --git a/atrium-oauth/identity/src/resolver.rs b/atrium-oauth/identity/src/resolver.rs index 5cfdff90..3f3d911d 100644 --- a/atrium-oauth/identity/src/resolver.rs +++ b/atrium-oauth/identity/src/resolver.rs @@ -1,247 +1 @@ -mod cache_impl; -mod cached_resolver; -mod throttled_resolver; - -pub use self::cached_resolver::{CachedResolver, CachedResolverConfig}; -pub use self::throttled_resolver::ThrottledResolver; pub use crate::error::Result; -use std::future::Future; -use std::hash::Hash; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait Resolver { - type Input: ?Sized; - type Output; - - fn resolve(&self, input: &Self::Input) -> impl Future>; -} - -pub trait Cacheable -where - Self: Sized + Resolver, - Self::Input: Sized, -{ - fn cached(self, config: CachedResolverConfig) -> CachedResolver; -} - -impl Cacheable for R -where - R: Sized + Resolver, - R::Input: Sized + Hash + Eq + Send + Sync + 'static, - R::Output: Clone + Send + Sync + 'static, -{ - fn cached(self, config: CachedResolverConfig) -> CachedResolver { - CachedResolver::new(self, config) - } -} - -pub trait Throttleable -where - Self: Sized + Resolver, - Self::Input: Sized, -{ - fn throttled(self) -> ThrottledResolver; -} - -impl Throttleable for R -where - R: Sized + Resolver, - R::Input: Clone + Hash + Eq + Send + Sync + 'static, - R::Output: Clone + Send + Sync + 'static, -{ - fn throttled(self) -> ThrottledResolver { - ThrottledResolver::new(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::Error; - use std::collections::HashMap; - use std::sync::Arc; - use std::time::Duration; - use tokio::sync::RwLock; - #[cfg(target_arch = "wasm32")] - use wasm_bindgen_test::wasm_bindgen_test; - - #[cfg(not(target_arch = "wasm32"))] - async fn sleep(duration: Duration) { - tokio::time::sleep(duration).await; - } - - #[cfg(target_arch = "wasm32")] - async fn sleep(duration: Duration) { - gloo_timers::future::sleep(duration).await; - } - - struct MockResolver { - data: HashMap, - counts: Arc>>, - } - - impl Resolver for MockResolver { - type Input = String; - type Output = String; - - async fn resolve(&self, input: &Self::Input) -> Result { - sleep(Duration::from_millis(10)).await; - *self.counts.write().await.entry(input.clone()).or_default() += 1; - if let Some(value) = self.data.get(input) { - Ok(value.clone()) - } else { - Err(Error::NotFound) - } - } - } - - fn mock_resolver(counts: Arc>>) -> MockResolver { - MockResolver { - data: [ - (String::from("k1"), String::from("v1")), - (String::from("k2"), String::from("v2")), - ] - .into_iter() - .collect(), - counts, - } - } - - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - async fn test_no_cached() { - let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = mock_resolver(counts.clone()); - for (input, expected) in [ - ("k1", Some("v1")), - ("k2", Some("v2")), - ("k2", Some("v2")), - ("k1", Some("v1")), - ("k3", None), - ("k1", Some("v1")), - ("k3", None), - ] { - let result = resolver.resolve(&input.to_string()).await; - match expected { - Some(value) => assert_eq!(result.expect("failed to resolve"), value), - None => assert!(result.is_err()), - } - } - assert_eq!( - *counts.read().await, - [(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),] - .into_iter() - .collect() - ); - } - - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - async fn test_cached() { - let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = mock_resolver(counts.clone()).cached(Default::default()); - for (input, expected) in [ - ("k1", Some("v1")), - ("k2", Some("v2")), - ("k2", Some("v2")), - ("k1", Some("v1")), - ("k3", None), - ("k1", Some("v1")), - ("k3", None), - ] { - let result = resolver.resolve(&input.to_string()).await; - match expected { - Some(value) => assert_eq!(result.expect("failed to resolve"), value), - None => assert!(result.is_err()), - } - } - assert_eq!( - *counts.read().await, - [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),] - .into_iter() - .collect() - ); - } - - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_cached_with_max_capacity() { - let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = mock_resolver(counts.clone()) - .cached(CachedResolverConfig { max_capacity: Some(1), ..Default::default() }); - for (input, expected) in [ - ("k1", Some("v1")), - ("k2", Some("v2")), - ("k2", Some("v2")), - ("k1", Some("v1")), - ("k3", None), - ("k1", Some("v1")), - ("k3", None), - ] { - let result = resolver.resolve(&input.to_string()).await; - match expected { - Some(value) => assert_eq!(result.expect("failed to resolve"), value), - None => assert!(result.is_err()), - } - } - assert_eq!( - *counts.read().await, - [(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),] - .into_iter() - .collect() - ); - } - - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_cached_with_time_to_live() { - let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = mock_resolver(counts.clone()).cached(CachedResolverConfig { - time_to_live: Some(Duration::from_millis(10)), - ..Default::default() - }); - for _ in 0..10 { - let result = resolver.resolve(&String::from("k1")).await; - assert_eq!(result.expect("failed to resolve"), "v1"); - } - sleep(Duration::from_millis(10)).await; - for _ in 0..10 { - let result = resolver.resolve(&String::from("k1")).await; - assert_eq!(result.expect("failed to resolve"), "v1"); - } - assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect()); - } - - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] - async fn test_throttled() { - let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = Arc::new(mock_resolver(counts.clone()).throttled()); - - let mut handles = Vec::new(); - for (input, expected) in [ - ("k1", Some("v1")), - ("k2", Some("v2")), - ("k2", Some("v2")), - ("k1", Some("v1")), - ("k3", None), - ("k1", Some("v1")), - ("k3", None), - ] { - let resolver = resolver.clone(); - handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) }); - } - for (result, expected) in futures::future::join_all(handles).await { - match expected { - Some(value) => assert_eq!(result.expect("failed to resolve"), value), - None => assert!(result.is_err()), - } - } - assert_eq!( - *counts.read().await, - [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),] - .into_iter() - .collect() - ); - } -} diff --git a/atrium-oauth/identity/src/resolver/cache_impl.rs b/atrium-oauth/identity/src/resolver/cache_impl.rs deleted file mode 100644 index c1b72c9c..00000000 --- a/atrium-oauth/identity/src/resolver/cache_impl.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[cfg(not(target_arch = "wasm32"))] -mod moka; -#[cfg(target_arch = "wasm32")] -mod wasm; - -#[cfg(not(target_arch = "wasm32"))] -pub use self::moka::MokaCache as CacheImpl; -#[cfg(target_arch = "wasm32")] -pub use self::wasm::WasmCache as CacheImpl; diff --git a/atrium-oauth/identity/src/resolver/cached_resolver.rs b/atrium-oauth/identity/src/resolver/cached_resolver.rs deleted file mode 100644 index 79a38295..00000000 --- a/atrium-oauth/identity/src/resolver/cached_resolver.rs +++ /dev/null @@ -1,61 +0,0 @@ -use super::cache_impl::CacheImpl; -use crate::error::Result; -use crate::Resolver; -use std::fmt::Debug; -use std::hash::Hash; -use std::time::Duration; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub(crate) trait Cache { - type Input: Hash + Eq + Sync + 'static; - type Output: Clone + Sync + 'static; - - fn new(config: CachedResolverConfig) -> Self; - async fn get(&self, key: &Self::Input) -> Option; - async fn set(&self, key: Self::Input, value: Self::Output); -} - -#[derive(Clone, Debug, Default)] -pub struct CachedResolverConfig { - pub max_capacity: Option, - pub time_to_live: Option, -} - -pub struct CachedResolver -where - R: Resolver, - R::Input: Sized, -{ - resolver: R, - cache: CacheImpl, -} - -impl CachedResolver -where - R: Resolver, - R::Input: Sized + Hash + Eq + Send + Sync + 'static, - R::Output: Clone + Send + Sync + 'static, -{ - pub fn new(resolver: R, config: CachedResolverConfig) -> Self { - Self { resolver, cache: CacheImpl::new(config) } - } -} - -impl Resolver for CachedResolver -where - R: Resolver + Send + Sync + 'static, - R::Input: Clone + Hash + Eq + Send + Sync + 'static + Debug, - R::Output: Clone + Send + Sync + 'static, -{ - type Input = R::Input; - type Output = R::Output; - - async fn resolve(&self, input: &Self::Input) -> Result { - if let Some(output) = self.cache.get(input).await { - return Ok(output); - } - let output = self.resolver.resolve(input).await?; - self.cache.set(input.clone(), output.clone()).await; - Ok(output) - } -} diff --git a/atrium-oauth/identity/src/resolver/throttled_resolver.rs b/atrium-oauth/identity/src/resolver/throttled_resolver.rs deleted file mode 100644 index 195473f0..00000000 --- a/atrium-oauth/identity/src/resolver/throttled_resolver.rs +++ /dev/null @@ -1,59 +0,0 @@ -use super::Resolver; -use crate::error::{Error, Result}; -use dashmap::{DashMap, Entry}; -use std::hash::Hash; -use std::sync::Arc; -use tokio::sync::broadcast::{channel, Sender}; -use tokio::sync::Mutex; - -type SharedSender = Arc>>>; - -pub struct ThrottledResolver -where - R: Resolver, - R::Input: Sized, -{ - resolver: R, - senders: Arc>>, -} - -impl ThrottledResolver -where - R: Resolver, - R::Input: Clone + Hash + Eq + Send + Sync + 'static, -{ - pub fn new(resolver: R) -> Self { - Self { resolver, senders: Arc::new(DashMap::new()) } - } -} - -impl Resolver for ThrottledResolver -where - R: Resolver + Send + Sync + 'static, - R::Input: Clone + Hash + Eq + Send + Sync + 'static, - R::Output: Clone + Send + Sync + 'static, -{ - type Input = R::Input; - type Output = R::Output; - - async fn resolve(&self, input: &Self::Input) -> Result { - match self.senders.entry(input.clone()) { - Entry::Occupied(occupied) => { - let tx = occupied.get().lock().await.clone(); - drop(occupied); - match tx.subscribe().recv().await.expect("recv") { - Some(result) => Ok(result), - None => Err(Error::NotFound), - } - } - Entry::Vacant(vacant) => { - let (tx, _) = channel(1); - vacant.insert(Arc::new(Mutex::new(tx.clone()))); - let result = self.resolver.resolve(input).await; - tx.send(result.as_ref().ok().cloned()).ok(); - self.senders.remove(input); - result - } - } - } -} diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index 99a0f3db..02596f59 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -15,6 +15,7 @@ keywords = ["atproto", "bluesky", "oauth"] [dependencies] atrium-api = { workspace = true, default-features = false } +atrium-common.workspace = true atrium-identity.workspace = true atrium-xrpc.workspace = true base64.workspace = true diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index ca1534a3..25e21b43 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -11,7 +11,8 @@ use crate::types::{ TryIntoOAuthClientMetadata, }; use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; -use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver}; +use atrium_common::resolver::Resolver; +use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; diff --git a/atrium-oauth/oauth-client/src/resolver.rs b/atrium-oauth/oauth-client/src/resolver.rs index ad36e813..d75f7abe 100644 --- a/atrium-oauth/oauth-client/src/resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver.rs @@ -1,3 +1,11 @@ +use atrium_common::resolver::CachedResolver; +use atrium_common::resolver::Resolver; +use atrium_common::resolver::ThrottledResolver; +use atrium_common::types::cached::r#impl::Cache; +use atrium_common::types::cached::r#impl::CacheImpl; +use atrium_common::types::cached::CacheConfig; +use atrium_common::types::cached::Cacheable; +use atrium_common::types::throttled::Throttleable; mod oauth_authorization_server_resolver; mod oauth_protected_resource_resolver; @@ -7,10 +15,7 @@ use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetad use atrium_identity::identity_resolver::{ IdentityResolver, IdentityResolverConfig, ResolvedIdentity, }; -use atrium_identity::resolver::{ - Cacheable, CachedResolver, CachedResolverConfig, Throttleable, ThrottledResolver, -}; -use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver}; +use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_identity::{Error, Result}; use atrium_xrpc::HttpClient; use std::marker::PhantomData; @@ -19,13 +24,13 @@ use std::time::Duration; #[derive(Clone, Debug)] pub struct OAuthAuthorizationServerMetadataResolverConfig { - pub cache: CachedResolverConfig, + pub cache: CacheConfig, } impl Default for OAuthAuthorizationServerMetadataResolverConfig { fn default() -> Self { Self { - cache: CachedResolverConfig { + cache: CacheConfig { max_capacity: Some(100), time_to_live: Some(Duration::from_secs(60)), }, @@ -35,13 +40,13 @@ impl Default for OAuthAuthorizationServerMetadataResolverConfig { #[derive(Clone, Debug)] pub struct OAuthProtectedResourceMetadataResolverConfig { - pub cache: CachedResolverConfig, + pub cache: CacheConfig, } impl Default for OAuthProtectedResourceMetadataResolverConfig { fn default() -> Self { Self { - cache: CachedResolverConfig { + cache: CacheConfig { max_capacity: Some(100), time_to_live: Some(Duration::from_secs(60)), }, @@ -81,11 +86,11 @@ where let protected_resource_resolver = DefaultOAuthProtectedResourceResolver::new(http_client.clone()) .throttled() - .cached(config.authorization_server_metadata.cache); + .cached(CacheImpl::new(config.authorization_server_metadata.cache)); let authorization_server_resolver = DefaultOAuthAuthorizationServerResolver::new(http_client.clone()) .throttled() - .cached(config.protected_resource_metadata.cache); + .cached(CacheImpl::new(config.protected_resource_metadata.cache)); Self { identity_resolver: IdentityResolver::new(IdentityResolverConfig { did_resolver: config.did_resolver, @@ -108,7 +113,9 @@ where &self, issuer: impl AsRef, ) -> Result { - self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await + let result = + self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await?; + result.ok_or_else(|| Error::NotFound) } async fn resolve_from_service(&self, input: &str) -> Result { // Assume first that input is a PDS URL (as required by ATPROTO) @@ -130,7 +137,8 @@ where &self, pds: &str, ) -> Result { - let rs_metadata = self.protected_resource_resolver.resolve(&pds.to_string()).await?; + let result = self.protected_resource_resolver.resolve(&pds.to_string()).await?; + let rs_metadata = result.ok_or_else(|| Error::NotFound)?; // ATPROTO requires one, and only one, authorization server entry // > That document MUST contain a single item in the authorization_servers array. // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata @@ -182,6 +190,7 @@ where { type Input = str; type Output = (OAuthAuthorizationServerMetadata, Option); + type Error = Error; async fn resolve(&self, input: &Self::Input) -> Result { // Allow using an entryway, or PDS url, directly as login input (e.g. diff --git a/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs b/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs index e38428fe..fd06f3a4 100644 --- a/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs @@ -1,5 +1,6 @@ +use crate::resolver::Resolver; use crate::types::OAuthAuthorizationServerMetadata; -use atrium_identity::{Error, Resolver, Result}; +use atrium_identity::{Error, Result}; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, StatusCode, Uri}; use atrium_xrpc::HttpClient; @@ -21,6 +22,7 @@ where { type Input = String; type Output = OAuthAuthorizationServerMetadata; + type Error = Error; async fn resolve(&self, issuer: &Self::Input) -> Result { let uri = Builder::from(issuer.parse::()?) diff --git a/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs b/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs index 98c2ea7a..9aecdfed 100644 --- a/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs @@ -1,5 +1,6 @@ use crate::types::OAuthProtectedResourceMetadata; -use atrium_identity::{Error, Resolver, Result}; +use atrium_common::resolver::Resolver; +use atrium_identity::{Error, Result}; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, StatusCode, Uri}; use atrium_xrpc::HttpClient; @@ -21,6 +22,7 @@ where { type Input = String; type Output = OAuthProtectedResourceMetadata; + type Error = Error; async fn resolve(&self, resource: &Self::Input) -> Result { let uri = Builder::from(resource.parse::()?)