Skip to content

Commit

Permalink
feat: Extract common abstractions (#244)
Browse files Browse the repository at this point in the history
* initialize crate

* add `Cached` trait

* add `Throttled` trait

* add resolvers

* add store

* add workflows

* fix `atrium-oauth`

* add error conversions

* change type visibility

* fix identity crate

* fix oauth-client crate

* small fix

* mofify crate authors

* change `Resolver` type signature

* apply suggestions

* fix `Throttled` tests

* fix wasm `CacheTrait`
  • Loading branch information
avdb13 authored Nov 21, 2024
1 parent ac1ad3f commit de18c4f
Show file tree
Hide file tree
Showing 37 changed files with 574 additions and 427 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/common.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Common
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
env:
CARGO_TERM_COLOR: always
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: |
cargo build -p atrium-common --verbose
- name: Run tests
run: |
cargo test -p atrium-common --lib
1 change: 1 addition & 0 deletions .github/workflows/wasm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ jobs:
- run: wasm-pack test --node atrium-xrpc
- run: wasm-pack test --node atrium-xrpc-client
- run: wasm-pack test --node atrium-oauth/identity
- run: wasm-pack test --node atrium-common
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
"atrium-api",
"atrium-common",
"atrium-crypto",
"atrium-xrpc",
"atrium-xrpc-client",
Expand All @@ -26,6 +27,7 @@ keywords = ["atproto", "bluesky"]
[workspace.dependencies]
# Intra-workspace dependencies
atrium-api = { version = "0.24.8", path = "atrium-api", default-features = false }
atrium-common = { version = "0.1.0", path = "atrium-common" }
atrium-identity = { version = "0.1.0", path = "atrium-oauth/identity" }
atrium-xrpc = { version = "0.12.0", path = "atrium-xrpc" }
atrium-xrpc-client = { version = "0.5.10", path = "atrium-xrpc-client" }
Expand Down
36 changes: 36 additions & 0 deletions atrium-common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[package]
name = "atrium-common"
version = "0.1.0"
authors = ["sugyan <[email protected]>", "avdb13 <[email protected]>"]
edition.workspace = true
rust-version.workspace = true
description = "Utility library for common abstractions in atproto"
documentation = "https://docs.rs/atrium-common"
readme = "README.md"
repository.workspace = true
license.workspace = true
keywords = ["atproto", "bluesky"]

[dependencies]
dashmap.workspace = true
thiserror.workspace = true
tokio = { workspace = true, default-features = false, features = ["sync"] }
trait-variant.workspace = true

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
moka = { workspace = true, features = ["future"] }

[target.'cfg(target_arch = "wasm32")'.dependencies]
lru.workspace = true
web-time.workspace = true

[dev-dependencies]
futures.workspace = true

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
gloo-timers.workspace = true
tokio = { workspace = true, features = ["time"] }
wasm-bindgen-test.workspace = true
3 changes: 3 additions & 0 deletions atrium-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod resolver;
pub mod store;
pub mod types;
222 changes: 222 additions & 0 deletions atrium-common/src/resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
mod cached;
mod throttled;

pub use self::cached::CachedResolver;
pub use self::throttled::ThrottledResolver;
use std::future::Future;

#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait Resolver {
type Input: ?Sized;
type Output;
type Error;

fn resolve(
&self,
input: &Self::Input,
) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
}

#[cfg(test)]
mod tests {
use super::*;
use crate::types::cached::r#impl::{Cache, CacheImpl};
use crate::types::cached::{CacheConfig, Cacheable};
use crate::types::throttled::Throttleable;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;

#[cfg(not(target_arch = "wasm32"))]
async fn sleep(duration: Duration) {
tokio::time::sleep(duration).await;
}

#[cfg(target_arch = "wasm32")]
async fn sleep(duration: Duration) {
gloo_timers::future::sleep(duration).await;
}

#[derive(Debug, PartialEq)]
struct Error;

type Result<T> = core::result::Result<T, Error>;

struct MockResolver {
data: HashMap<String, String>,
counts: Arc<RwLock<HashMap<String, usize>>>,
}

impl Resolver for MockResolver {
type Input = String;
type Output = String;
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
sleep(Duration::from_millis(10)).await;
*self.counts.write().await.entry(input.clone()).or_default() += 1;
if let Some(value) = self.data.get(input) {
Ok(value.clone())
} else {
Err(Error)
}
}
}

fn mock_resolver(counts: Arc<RwLock<HashMap<String, usize>>>) -> MockResolver {
MockResolver {
data: [
(String::from("k1"), String::from("v1")),
(String::from("k2"), String::from("v2")),
]
.into_iter()
.collect(),
counts,
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_no_cached() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone());
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_cached() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default()));
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_cached_with_max_capacity() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone())
.cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() }));
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let result = resolver.resolve(&input.to_string()).await;
match expected {
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),]
.into_iter()
.collect()
);
}

#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_cached_with_time_to_live() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig {
time_to_live: Some(Duration::from_millis(10)),
..Default::default()
}));
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve"), "v1");
}
sleep(Duration::from_millis(10)).await;
for _ in 0..10 {
let result = resolver.resolve(&String::from("k1")).await;
assert_eq!(result.expect("failed to resolve"), "v1");
}
assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect());
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_throttled() {
let counts = Arc::new(RwLock::new(HashMap::new()));
let resolver = Arc::new(mock_resolver(counts.clone()).throttled());

let mut handles = Vec::new();
for (input, expected) in [
("k1", Some("v1")),
("k2", Some("v2")),
("k2", Some("v2")),
("k1", Some("v1")),
("k3", None),
("k1", Some("v1")),
("k3", None),
] {
let resolver = resolver.clone();
handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) });
}
for (result, expected) in futures::future::join_all(handles).await {
let result = result.and_then(|opt| opt.ok_or(Error));

match expected {
Some(value) => {
assert_eq!(result.expect("failed to resolve"), value)
}
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
}
}
assert_eq!(
*counts.read().await,
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),]
.into_iter()
.collect()
);
}
}
31 changes: 31 additions & 0 deletions atrium-common/src/resolver/cached.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::hash::Hash;

use crate::types::cached::r#impl::{Cache, CacheImpl};
use crate::types::cached::Cached;

use super::Resolver;

pub type CachedResolver<R> = Cached<R, CacheImpl<<R as Resolver>::Input, <R as Resolver>::Output>>;

impl<R, C> Resolver for Cached<R, C>
where
R: Resolver + Send + Sync + 'static,
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
R::Output: Clone + Send + Sync + 'static,
C: Cache<Input = R::Input, Output = R::Output> + Send + Sync + 'static,
C::Input: Clone + Hash + Eq + Send + Sync + 'static,
C::Output: Clone + Send + Sync + 'static,
{
type Input = R::Input;
type Output = R::Output;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
if let Some(output) = self.cache.get(input).await {
return Ok(output);
}
let output = self.inner.resolve(input).await?;
self.cache.set(input.clone(), output.clone()).await;
Ok(output)
}
}
43 changes: 43 additions & 0 deletions atrium-common/src/resolver/throttled.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{hash::Hash, sync::Arc};

use dashmap::{DashMap, Entry};
use tokio::sync::broadcast::{channel, Sender};
use tokio::sync::Mutex;

use crate::types::throttled::Throttled;

use super::Resolver;

pub type SenderMap<R> =
DashMap<<R as Resolver>::Input, Arc<Mutex<Sender<Option<<R as Resolver>::Output>>>>>;

pub type ThrottledResolver<R> = Throttled<R, SenderMap<R>>;

impl<R> Resolver for Throttled<R, SenderMap<R>>
where
R: Resolver + Send + Sync + 'static,
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
R::Output: Clone + Send + Sync + 'static,
{
type Input = R::Input;
type Output = Option<R::Output>;
type Error = R::Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
match self.pending.entry(input.clone()) {
Entry::Occupied(occupied) => {
let tx = occupied.get().lock().await.clone();
drop(occupied);
Ok(tx.subscribe().recv().await.expect("recv"))
}
Entry::Vacant(vacant) => {
let (tx, _) = channel(1);
vacant.insert(Arc::new(Mutex::new(tx.clone())));
let result = self.inner.resolve(input).await;
tx.send(result.as_ref().ok().cloned()).ok();
self.pending.remove(input);
result.map(Some)
}
}
}
}
Loading

0 comments on commit de18c4f

Please sign in to comment.