diff --git a/examples/examples/axum.rs b/examples/examples/axum.rs index 8db2503..dad99d6 100644 --- a/examples/examples/axum.rs +++ b/examples/examples/axum.rs @@ -8,10 +8,10 @@ use hitbox_tower::{ use hitbox_redis::RedisBackend; use hitbox_tower::{Cache, EndpointConfig}; -async fn handler_result(Path(_name): Path) -> Result { +async fn handler_result(Path(_name): Path) -> Result { //dbg!("axum::handler_result"); // Ok(format!("Hello, {name}")) - Err("error".to_owned()) + Err(StatusCode::INTERNAL_SERVER_ERROR) } async fn handler() -> String { diff --git a/examples/examples/tower.rs b/examples/examples/tower.rs index c866dc5..864f246 100644 --- a/examples/examples/tower.rs +++ b/examples/examples/tower.rs @@ -4,12 +4,12 @@ use hitbox_tower::Cache; use hyper::{Body, Server}; use std::net::SocketAddr; -use http::{Request, Response}; +use http::{Method, Request, Response}; use tower::make::Shared; -async fn handle(_: Request) -> Result, String> { - Ok(Response::new("Hello, World!".into())) - // Err("handler error".to_owned()) +async fn handle(_: Request) -> http::Result> { + // Ok(Response::new("Hello, World!".into())) + Err(http::Error::from(Method::from_bytes(&[0x01]).unwrap_err())) } #[tokio::main] @@ -24,7 +24,7 @@ async fn main() { let redis = RedisBackend::builder().build().unwrap(); let service = tower::ServiceBuilder::new() - .layer(Cache::builder().backend(inmemory).build()) + // .layer(Cache::builder().backend(inmemory).build()) .layer(Cache::builder().backend(redis).build()) .service_fn(handle); diff --git a/hitbox-backend/src/serializer.rs b/hitbox-backend/src/serializer.rs index 227c836..f14016f 100644 --- a/hitbox-backend/src/serializer.rs +++ b/hitbox-backend/src/serializer.rs @@ -129,7 +129,7 @@ impl Serializer for BinSerializer> { #[cfg(test)] mod test { use async_trait::async_trait; - use hitbox_core::CacheableResponse; + use hitbox_core::{CachePolicy, CacheableResponse, PredicateResult}; use super::*; @@ -142,9 +142,25 @@ mod test { #[async_trait] impl CacheableResponse for Test { type Cached = Self; + type Subject = Self; + + async fn cache_policy

(self, predicates: P) -> hitbox_core::ResponseCachePolicy + where + P: hitbox_core::Predicate + Send + Sync, + { + match predicates.check(self).await { + PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await { + CachePolicy::Cacheable(res) => { + CachePolicy::Cacheable(CachedValue::new(res, Utc::now())) + } + CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(res), + }, + PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res), + } + } - async fn into_cached(self) -> Self::Cached { - self + async fn into_cached(self) -> CachePolicy { + CachePolicy::Cacheable(self) } async fn from_cached(cached: Self::Cached) -> Self { cached diff --git a/hitbox-core/Cargo.toml b/hitbox-core/Cargo.toml index 6439544..471b263 100644 --- a/hitbox-core/Cargo.toml +++ b/hitbox-core/Cargo.toml @@ -17,3 +17,9 @@ keywords = ["cache", "async", "cache-backend", "hitbox", "tokio"] [dependencies] async-trait = "0.1.73" chrono = { version = "0.4.26", default-features = false, features = ["clock"] } + +[dev-dependencies] +tokio = { version = "1", default-features = false, features = [ + "macros", + "rt-multi-thread", +] } diff --git a/hitbox-core/src/response.rs b/hitbox-core/src/response.rs index 35c3e9d..4c0bb7f 100644 --- a/hitbox-core/src/response.rs +++ b/hitbox-core/src/response.rs @@ -34,20 +34,13 @@ where Self::Cached: Clone, { type Cached; + type Subject: CacheableResponse; async fn cache_policy

(self, predicates: P) -> ResponseCachePolicy where - P: Predicate + Send + Sync, - { - match predicates.check(self).await { - PredicateResult::Cacheable(res) => { - CachePolicy::Cacheable(CachedValue::new(res.into_cached().await, Utc::now())) - } - PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res), - } - } + P: Predicate + Send + Sync; - async fn into_cached(self) -> Self::Cached; + async fn into_cached(self) -> CachePolicy; async fn from_cached(cached: Self::Cached) -> Self; } @@ -60,12 +53,37 @@ where T::Cached: Send, { type Cached = ::Cached; + type Subject = T; + + async fn cache_policy

(self, predicates: P) -> ResponseCachePolicy + where + P: Predicate + Send + Sync, + { + match self { + Ok(response) => match predicates.check(response).await { + PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await { + CachePolicy::Cacheable(res) => { + CachePolicy::Cacheable(CachedValue::new(res, Utc::now())) + } + CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(Ok(res)), + }, + PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(Ok(res)), + }, + Err(error) => ResponseCachePolicy::NonCacheable(Err(error)), + } + } - async fn into_cached(self) -> Self::Cached { - unimplemented!() + async fn into_cached(self) -> CachePolicy { + match self { + Ok(response) => match response.into_cached().await { + CachePolicy::Cacheable(res) => CachePolicy::Cacheable(res), + CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(Ok(res)), + }, + Err(error) => CachePolicy::NonCacheable(Err(error)), + } } - async fn from_cached(_cached: Self::Cached) -> Self { - unimplemented!() + async fn from_cached(cached: Self::Cached) -> Self { + Ok(T::from_cached(cached).await) } } diff --git a/hitbox-core/tests/response.rs b/hitbox-core/tests/response.rs new file mode 100644 index 0000000..53b0e33 --- /dev/null +++ b/hitbox-core/tests/response.rs @@ -0,0 +1,74 @@ +use async_trait::async_trait; +use chrono::Utc; +use hitbox_core::{CachePolicy, CacheableResponse, CachedValue, Predicate, PredicateResult}; + +#[derive(Clone, Debug)] +struct TestResponse { + field1: String, + field2: u8, +} + +impl TestResponse { + pub fn new() -> Self { + Self { + field1: "nope".to_owned(), + field2: 42, + } + } +} + +#[async_trait] +impl CacheableResponse for TestResponse { + type Cached = Self; + type Subject = Self; + + async fn cache_policy

(self, predicates: P) -> hitbox_core::ResponseCachePolicy + where + P: hitbox_core::Predicate + Send + Sync, + { + match predicates.check(self).await { + PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await { + CachePolicy::Cacheable(res) => { + CachePolicy::Cacheable(CachedValue::new(res, Utc::now())) + } + CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(res), + }, + PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res), + } + } + + async fn into_cached(self) -> CachePolicy { + CachePolicy::Cacheable(self) + } + async fn from_cached(cached: Self::Cached) -> Self { + cached + } +} + +struct NeuralPredicate {} + +impl NeuralPredicate { + fn new() -> Self { + NeuralPredicate {} + } +} + +#[async_trait] +impl Predicate for NeuralPredicate { + type Subject = TestResponse; + + async fn check(&self, subject: Self::Subject) -> PredicateResult { + PredicateResult::Cacheable(subject) + } +} + +#[tokio::test] +async fn test_cacheable_result() { + let response: Result = Ok(TestResponse::new()); + let policy = response.cache_policy(NeuralPredicate::new()).await; + dbg!(&policy); + + let response: Result = Err(()); + let policy = response.cache_policy(NeuralPredicate::new()).await; + dbg!(&policy); +} diff --git a/hitbox-http/src/response.rs b/hitbox-http/src/response.rs index 346deca..0b4f677 100644 --- a/hitbox-http/src/response.rs +++ b/hitbox-http/src/response.rs @@ -2,7 +2,8 @@ use std::{collections::HashMap, fmt::Debug}; use async_trait::async_trait; use bytes::Bytes; -use hitbox::CacheableResponse; +use chrono::Utc; +use hitbox::{predicate::PredicateResult, CachePolicy, CacheableResponse, CachedValue}; use http::{response::Parts, Response}; use hyper::body::{to_bytes, HttpBody}; use serde::{Deserialize, Serialize}; @@ -71,9 +72,25 @@ where ResBody::Data: Send, { type Cached = SerializableHttpResponse; + type Subject = Self; - async fn into_cached(self) -> Self::Cached { - SerializableHttpResponse { + async fn cache_policy

(self, predicates: P) -> hitbox::ResponseCachePolicy + where + P: hitbox::Predicate + Send + Sync, + { + match predicates.check(self).await { + PredicateResult::Cacheable(cacheable) => match cacheable.into_cached().await { + CachePolicy::Cacheable(res) => { + CachePolicy::Cacheable(CachedValue::new(res, Utc::now())) + } + CachePolicy::NonCacheable(res) => CachePolicy::NonCacheable(res), + }, + PredicateResult::NonCacheable(res) => CachePolicy::NonCacheable(res), + } + } + + async fn into_cached(self) -> CachePolicy { + CachePolicy::Cacheable(SerializableHttpResponse { status: 200, version: "HTTP/1.1".to_owned(), body: to_bytes(self.body.into_inner_body()) @@ -86,7 +103,7 @@ where .into_iter() .map(|(h, v)| (h.unwrap().to_string(), v.to_str().unwrap().to_string())) .collect(), - } + }) } async fn from_cached(cached: Self::Cached) -> Self { diff --git a/hitbox-tower/src/future.rs b/hitbox-tower/src/future.rs index 7553c34..34787a7 100644 --- a/hitbox-tower/src/future.rs +++ b/hitbox-tower/src/future.rs @@ -1,4 +1,6 @@ use std::{ + any::type_name, + fmt::Debug, marker::PhantomData, pin::Pin, task::{Context, Poll}, @@ -25,52 +27,68 @@ impl Transformer { } } -impl Transform, CacheableHttpResponse> +impl + Transform, Result, S::Error>> for Transformer where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send, ReqBody: Send + 'static, ResBody: FromBytes, + + // debug bounds + S::Error: Debug, { - type Future = UpstreamFuture; + type Future = UpstreamFuture; type Response = Result, S::Error>; fn upstream_transform(&self, req: CacheableHttpRequest) -> Self::Future { UpstreamFuture::new(self.inner.clone(), req) } - fn response_transform(&self, res: CacheableHttpResponse) -> Self::Response { - Ok(res.into_response()) + fn response_transform( + &self, + res: Result, S::Error>, + ) -> Self::Response { + res.map(CacheableHttpResponse::into_response) } } #[pin_project] -pub struct UpstreamFuture { - inner_future: BoxFuture<'static, CacheableHttpResponse>, +pub struct UpstreamFuture { + inner_future: BoxFuture<'static, Result, E>>, } -impl UpstreamFuture { +impl UpstreamFuture { pub fn new(mut inner_service: S, req: CacheableHttpRequest) -> Self where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response, Error = E> + Send + 'static, S::Future: Send, ReqBody: Send + 'static, ResBody: FromBytes, + + // debug bounds + S::Error: Debug, { let inner_future = Box::pin(async move { let res = inner_service.call(req.into_request()).await; - match res { - Ok(res) => CacheableHttpResponse::from_response(res), - _ => unimplemented!(), - } + // CacheableHttpResponse::from_response(res.unwrap()) + match &res { + Ok(res) => { + dbg!(res.status()); + } + Err(err) => { + dbg!(err); + } + }; + res.map(CacheableHttpResponse::from_response) }); UpstreamFuture { inner_future } } } -impl Future for UpstreamFuture { - type Output = CacheableHttpResponse; +impl Future for UpstreamFuture { + type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); this.inner_future.as_mut().poll(cx) diff --git a/hitbox-tower/src/service.rs b/hitbox-tower/src/service.rs index df485f1..10c966d 100644 --- a/hitbox-tower/src/service.rs +++ b/hitbox-tower/src/service.rs @@ -53,13 +53,14 @@ where ResBody: FromBytes + HttpBody + Send + 'static, ResBody::Error: Debug, ResBody::Data: Send, + S::Error: Debug + Send, { type Response = Response; type Error = S::Error; type Future = CacheFuture< B, CacheableHttpRequest, - CacheableHttpResponse, + Result, S::Error>, Transformer, >; diff --git a/hitbox/src/fsm/future.rs b/hitbox/src/fsm/future.rs index 38e0e26..fc3a466 100644 --- a/hitbox/src/fsm/future.rs +++ b/hitbox/src/fsm/future.rs @@ -15,7 +15,7 @@ use tracing::debug; use crate::{ backend::CacheBackend, - fsm::{states::StateProj, PollCache, State}, + fsm::{states::StateProj, PollCacheFuture, State}, CacheKey, CacheableRequest, Extractor, Predicate, }; @@ -166,11 +166,11 @@ where request: Option, cache_key: Option, #[pin] - state: State<::Output, Res, Req>, + state: State, #[pin] - poll_cache: Option>, + poll_cache: Option>, request_predicates: Arc + Send + Sync>, - response_predicates: Arc + Send + Sync>, + response_predicates: Arc + Send + Sync>, key_extractors: Arc + Send + Sync>, policy: Arc, } @@ -188,7 +188,7 @@ where request: Req, transformer: T, request_predicates: Arc + Send + Sync>, - response_predicates: Arc + Send + Sync>, + response_predicates: Arc + Send + Sync>, key_extractors: Arc + Send + Sync>, policy: Arc, ) -> Self { diff --git a/hitbox/src/fsm/mod.rs b/hitbox/src/fsm/mod.rs index 73b1f4e..457cfce 100644 --- a/hitbox/src/fsm/mod.rs +++ b/hitbox/src/fsm/mod.rs @@ -2,4 +2,4 @@ mod future; mod states; pub use future::{CacheFuture, Transform}; -pub use states::{PollCache, State, UpdateCache}; +pub use states::{PollCacheFuture, State, UpdateCache}; diff --git a/hitbox/src/fsm/states.rs b/hitbox/src/fsm/states.rs index c5aa596..e9fb0c9 100644 --- a/hitbox/src/fsm/states.rs +++ b/hitbox/src/fsm/states.rs @@ -7,54 +7,57 @@ use pin_project::pin_project; use crate::{CacheState, CacheableResponse, CachedValue}; -pub type CacheResult = Result>, BackendError>; -pub type PollCache = BoxFuture<'static, CacheResult>; -pub type UpdateCache = BoxFuture<'static, (Result<(), BackendError>, R)>; +pub type CacheResult = Result>, BackendError>; +pub type PollCacheFuture = BoxFuture<'static, CacheResult>; +pub type UpdateCache = BoxFuture<'static, (Result<(), BackendError>, T)>; +pub type RequestCachePolicyFuture = BoxFuture<'static, RequestCachePolicy>; +pub type CacheStateFuture = BoxFuture<'static, CacheState>; +pub type UpstreamFuture = BoxFuture<'static, T>; #[allow(missing_docs)] #[pin_project(project = StateProj)] -pub enum State +pub enum State where - C: CacheableResponse, + Res: CacheableResponse, { Initial, CheckRequestCachePolicy { #[pin] - cache_policy_future: BoxFuture<'static, RequestCachePolicy>, + cache_policy_future: RequestCachePolicyFuture, }, PollCache { #[pin] - poll_cache: PollCache, - request: Option, + poll_cache: PollCacheFuture, + request: Option, }, // CachePolled { // cache_result: CacheResult, // }, CheckCacheState { - cache_state: BoxFuture<'static, CacheState>, + cache_state: CacheStateFuture, }, PollUpstream { - upstream_future: BoxFuture<'static, C>, + upstream_future: UpstreamFuture, }, UpstreamPolled { - upstream_result: Option, + upstream_result: Option, }, CheckResponseCachePolicy { #[pin] - cache_policy: BoxFuture<'static, ResponseCachePolicy>, + cache_policy: BoxFuture<'static, ResponseCachePolicy>, }, UpdateCache { #[pin] - update_cache_future: UpdateCache, + update_cache_future: UpdateCache, }, Response { - response: Option, + response: Option, }, } -impl Debug for State +impl Debug for State where - C: CacheableResponse, + Res: CacheableResponse, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self {