-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Extract common abstractions (#244)
* initialize crate * add `Cached` trait * add `Throttled` trait * add resolvers * add store * add workflows * fix `atrium-oauth` * add error conversions * change type visibility * fix identity crate * fix oauth-client crate * small fix * mofify crate authors * change `Resolver` type signature * apply suggestions * fix `Throttled` tests * fix wasm `CacheTrait`
- Loading branch information
Showing
37 changed files
with
574 additions
and
427 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
name: Common | ||
on: | ||
push: | ||
branches: ["main"] | ||
pull_request: | ||
branches: ["main"] | ||
env: | ||
CARGO_TERM_COLOR: always | ||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Build | ||
run: | | ||
cargo build -p atrium-common --verbose | ||
- name: Run tests | ||
run: | | ||
cargo test -p atrium-common --lib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
[package] | ||
name = "atrium-common" | ||
version = "0.1.0" | ||
authors = ["sugyan <[email protected]>", "avdb13 <[email protected]>"] | ||
edition.workspace = true | ||
rust-version.workspace = true | ||
description = "Utility library for common abstractions in atproto" | ||
documentation = "https://docs.rs/atrium-common" | ||
readme = "README.md" | ||
repository.workspace = true | ||
license.workspace = true | ||
keywords = ["atproto", "bluesky"] | ||
|
||
[dependencies] | ||
dashmap.workspace = true | ||
thiserror.workspace = true | ||
tokio = { workspace = true, default-features = false, features = ["sync"] } | ||
trait-variant.workspace = true | ||
|
||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] | ||
moka = { workspace = true, features = ["future"] } | ||
|
||
[target.'cfg(target_arch = "wasm32")'.dependencies] | ||
lru.workspace = true | ||
web-time.workspace = true | ||
|
||
[dev-dependencies] | ||
futures.workspace = true | ||
|
||
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] | ||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } | ||
|
||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies] | ||
gloo-timers.workspace = true | ||
tokio = { workspace = true, features = ["time"] } | ||
wasm-bindgen-test.workspace = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
pub mod resolver; | ||
pub mod store; | ||
pub mod types; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
mod cached; | ||
mod throttled; | ||
|
||
pub use self::cached::CachedResolver; | ||
pub use self::throttled::ThrottledResolver; | ||
use std::future::Future; | ||
|
||
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] | ||
pub trait Resolver { | ||
type Input: ?Sized; | ||
type Output; | ||
type Error; | ||
|
||
fn resolve( | ||
&self, | ||
input: &Self::Input, | ||
) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>; | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::types::cached::r#impl::{Cache, CacheImpl}; | ||
use crate::types::cached::{CacheConfig, Cacheable}; | ||
use crate::types::throttled::Throttleable; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
use std::time::Duration; | ||
use tokio::sync::RwLock; | ||
#[cfg(target_arch = "wasm32")] | ||
use wasm_bindgen_test::wasm_bindgen_test; | ||
|
||
#[cfg(not(target_arch = "wasm32"))] | ||
async fn sleep(duration: Duration) { | ||
tokio::time::sleep(duration).await; | ||
} | ||
|
||
#[cfg(target_arch = "wasm32")] | ||
async fn sleep(duration: Duration) { | ||
gloo_timers::future::sleep(duration).await; | ||
} | ||
|
||
#[derive(Debug, PartialEq)] | ||
struct Error; | ||
|
||
type Result<T> = core::result::Result<T, Error>; | ||
|
||
struct MockResolver { | ||
data: HashMap<String, String>, | ||
counts: Arc<RwLock<HashMap<String, usize>>>, | ||
} | ||
|
||
impl Resolver for MockResolver { | ||
type Input = String; | ||
type Output = String; | ||
type Error = Error; | ||
|
||
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> { | ||
sleep(Duration::from_millis(10)).await; | ||
*self.counts.write().await.entry(input.clone()).or_default() += 1; | ||
if let Some(value) = self.data.get(input) { | ||
Ok(value.clone()) | ||
} else { | ||
Err(Error) | ||
} | ||
} | ||
} | ||
|
||
fn mock_resolver(counts: Arc<RwLock<HashMap<String, usize>>>) -> MockResolver { | ||
MockResolver { | ||
data: [ | ||
(String::from("k1"), String::from("v1")), | ||
(String::from("k2"), String::from("v2")), | ||
] | ||
.into_iter() | ||
.collect(), | ||
counts, | ||
} | ||
} | ||
|
||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] | ||
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] | ||
async fn test_no_cached() { | ||
let counts = Arc::new(RwLock::new(HashMap::new())); | ||
let resolver = mock_resolver(counts.clone()); | ||
for (input, expected) in [ | ||
("k1", Some("v1")), | ||
("k2", Some("v2")), | ||
("k2", Some("v2")), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
] { | ||
let result = resolver.resolve(&input.to_string()).await; | ||
match expected { | ||
Some(value) => assert_eq!(result.expect("failed to resolve"), value), | ||
None => assert_eq!(result.expect_err("succesfully resolved"), Error), | ||
} | ||
} | ||
assert_eq!( | ||
*counts.read().await, | ||
[(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),] | ||
.into_iter() | ||
.collect() | ||
); | ||
} | ||
|
||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] | ||
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] | ||
async fn test_cached() { | ||
let counts = Arc::new(RwLock::new(HashMap::new())); | ||
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default())); | ||
for (input, expected) in [ | ||
("k1", Some("v1")), | ||
("k2", Some("v2")), | ||
("k2", Some("v2")), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
] { | ||
let result = resolver.resolve(&input.to_string()).await; | ||
match expected { | ||
Some(value) => assert_eq!(result.expect("failed to resolve"), value), | ||
None => assert_eq!(result.expect_err("succesfully resolved"), Error), | ||
} | ||
} | ||
assert_eq!( | ||
*counts.read().await, | ||
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),] | ||
.into_iter() | ||
.collect() | ||
); | ||
} | ||
|
||
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] | ||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] | ||
async fn test_cached_with_max_capacity() { | ||
let counts = Arc::new(RwLock::new(HashMap::new())); | ||
let resolver = mock_resolver(counts.clone()) | ||
.cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() })); | ||
for (input, expected) in [ | ||
("k1", Some("v1")), | ||
("k2", Some("v2")), | ||
("k2", Some("v2")), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
] { | ||
let result = resolver.resolve(&input.to_string()).await; | ||
match expected { | ||
Some(value) => assert_eq!(result.expect("failed to resolve"), value), | ||
None => assert_eq!(result.expect_err("succesfully resolved"), Error), | ||
} | ||
} | ||
assert_eq!( | ||
*counts.read().await, | ||
[(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),] | ||
.into_iter() | ||
.collect() | ||
); | ||
} | ||
|
||
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] | ||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] | ||
async fn test_cached_with_time_to_live() { | ||
let counts = Arc::new(RwLock::new(HashMap::new())); | ||
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig { | ||
time_to_live: Some(Duration::from_millis(10)), | ||
..Default::default() | ||
})); | ||
for _ in 0..10 { | ||
let result = resolver.resolve(&String::from("k1")).await; | ||
assert_eq!(result.expect("failed to resolve"), "v1"); | ||
} | ||
sleep(Duration::from_millis(10)).await; | ||
for _ in 0..10 { | ||
let result = resolver.resolve(&String::from("k1")).await; | ||
assert_eq!(result.expect("failed to resolve"), "v1"); | ||
} | ||
assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect()); | ||
} | ||
|
||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] | ||
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] | ||
async fn test_throttled() { | ||
let counts = Arc::new(RwLock::new(HashMap::new())); | ||
let resolver = Arc::new(mock_resolver(counts.clone()).throttled()); | ||
|
||
let mut handles = Vec::new(); | ||
for (input, expected) in [ | ||
("k1", Some("v1")), | ||
("k2", Some("v2")), | ||
("k2", Some("v2")), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
("k1", Some("v1")), | ||
("k3", None), | ||
] { | ||
let resolver = resolver.clone(); | ||
handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) }); | ||
} | ||
for (result, expected) in futures::future::join_all(handles).await { | ||
let result = result.and_then(|opt| opt.ok_or(Error)); | ||
|
||
match expected { | ||
Some(value) => { | ||
assert_eq!(result.expect("failed to resolve"), value) | ||
} | ||
None => assert_eq!(result.expect_err("succesfully resolved"), Error), | ||
} | ||
} | ||
assert_eq!( | ||
*counts.read().await, | ||
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),] | ||
.into_iter() | ||
.collect() | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use std::hash::Hash; | ||
|
||
use crate::types::cached::r#impl::{Cache, CacheImpl}; | ||
use crate::types::cached::Cached; | ||
|
||
use super::Resolver; | ||
|
||
pub type CachedResolver<R> = Cached<R, CacheImpl<<R as Resolver>::Input, <R as Resolver>::Output>>; | ||
|
||
impl<R, C> Resolver for Cached<R, C> | ||
where | ||
R: Resolver + Send + Sync + 'static, | ||
R::Input: Clone + Hash + Eq + Send + Sync + 'static, | ||
R::Output: Clone + Send + Sync + 'static, | ||
C: Cache<Input = R::Input, Output = R::Output> + Send + Sync + 'static, | ||
C::Input: Clone + Hash + Eq + Send + Sync + 'static, | ||
C::Output: Clone + Send + Sync + 'static, | ||
{ | ||
type Input = R::Input; | ||
type Output = R::Output; | ||
type Error = R::Error; | ||
|
||
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> { | ||
if let Some(output) = self.cache.get(input).await { | ||
return Ok(output); | ||
} | ||
let output = self.inner.resolve(input).await?; | ||
self.cache.set(input.clone(), output.clone()).await; | ||
Ok(output) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
use std::{hash::Hash, sync::Arc}; | ||
|
||
use dashmap::{DashMap, Entry}; | ||
use tokio::sync::broadcast::{channel, Sender}; | ||
use tokio::sync::Mutex; | ||
|
||
use crate::types::throttled::Throttled; | ||
|
||
use super::Resolver; | ||
|
||
pub type SenderMap<R> = | ||
DashMap<<R as Resolver>::Input, Arc<Mutex<Sender<Option<<R as Resolver>::Output>>>>>; | ||
|
||
pub type ThrottledResolver<R> = Throttled<R, SenderMap<R>>; | ||
|
||
impl<R> Resolver for Throttled<R, SenderMap<R>> | ||
where | ||
R: Resolver + Send + Sync + 'static, | ||
R::Input: Clone + Hash + Eq + Send + Sync + 'static, | ||
R::Output: Clone + Send + Sync + 'static, | ||
{ | ||
type Input = R::Input; | ||
type Output = Option<R::Output>; | ||
type Error = R::Error; | ||
|
||
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> { | ||
match self.pending.entry(input.clone()) { | ||
Entry::Occupied(occupied) => { | ||
let tx = occupied.get().lock().await.clone(); | ||
drop(occupied); | ||
Ok(tx.subscribe().recv().await.expect("recv")) | ||
} | ||
Entry::Vacant(vacant) => { | ||
let (tx, _) = channel(1); | ||
vacant.insert(Arc::new(Mutex::new(tx.clone()))); | ||
let result = self.inner.resolve(input).await; | ||
tx.send(result.as_ref().ok().cloned()).ok(); | ||
self.pending.remove(input); | ||
result.map(Some) | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.