Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix for chaining multiple tower-http middlewares. #889

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions web/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@
```
- update `xitca-http` to version `0.2.0`.
- update `http-encoding` to version `0.2.0`.

## Fix
- fix nested App routing. `App::new().at("/foo", App::new().at("/"))` would be successfully matching against `/foo/`
- fix bug where certain tower-http layers causing compile issue.
- fix bug where multiple tower-http layers can't be chained together with `ServiceExt::enclosed`.
73 changes: 17 additions & 56 deletions web/src/middleware/tower_http_compat.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use core::{
cell::RefCell,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};

use std::rc::Rc;

use tower_layer::Layer;
use xitca_unsafe_collection::fake::{FakeClone, FakeSend, FakeSync};

use crate::{
context::WebContext,
Expand All @@ -28,24 +26,10 @@ use crate::{
/// by it must be able to handle it's mutation or utilize [TypeEraser] to erase the mutation.
///
/// [TypeEraser]: crate::middleware::eraser::TypeEraser
pub struct TowerHttpCompat<L, C, ReqB, ResB, Err> {
layer: L,
_phantom: PhantomData<fn(C, ReqB, ResB, Err)>,
}

impl<L, C, ReqB, ResB, Err> Clone for TowerHttpCompat<L, C, ReqB, ResB, Err>
where
L: Clone,
{
fn clone(&self) -> Self {
Self {
layer: self.layer.clone(),
_phantom: PhantomData,
}
}
}
#[derive(Clone)]
pub struct TowerHttpCompat<L>(L);

impl<L, C, ReqB, ResB, Err> TowerHttpCompat<L, C, ReqB, ResB, Err> {
impl<L> TowerHttpCompat<L> {
/// Construct a new xitca-web middleware from tower-http layer type.
///
/// # Limitation:
Expand All @@ -69,41 +53,29 @@ impl<L, C, ReqB, ResB, Err> TowerHttpCompat<L, C, ReqB, ResB, Err> {
/// # todo!()
/// # }
/// ```
pub fn new(layer: L) -> Self {
Self {
layer,
_phantom: PhantomData,
}
pub const fn new(layer: L) -> Self {
Self(layer)
}
}

impl<L, S, E, C, ReqB, ResB, Err> Service<Result<S, E>> for TowerHttpCompat<L, C, ReqB, ResB, Err>
impl<L, S, E> Service<Result<S, E>> for TowerHttpCompat<L>
where
L: Layer<CompatLayer<S, C, ResB, Err>>,
S: for<'r> Service<WebContext<'r, C, ReqB>, Response = WebResponse<ResB>, Error = Err>,
ReqB: 'static,
L: Layer<CompatLayer<S>>,
{
type Response = TowerCompatService<L::Service>;
type Error = E;

async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
res.map(|service| {
let service = self.layer.layer(CompatLayer {
service: Rc::new(service),
_phantom: PhantomData,
});
let service = self.0.layer(CompatLayer(Rc::new(service)));
TowerCompatService::new(service)
})
}
}

pub struct CompatLayer<S, C, ResB, Err> {
service: Rc<S>,
_phantom: PhantomData<fn(C, ResB, Err)>,
}
pub struct CompatLayer<S>(Rc<S>);

impl<S, C, ReqB, ResB, Err> tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>>>>
for CompatLayer<S, C, ResB, Err>
impl<S, C, ReqB, ResB, Err> tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>, C>>> for CompatLayer<S>
where
S: for<'r> Service<WebContext<'r, C, ReqB>, Response = WebResponse<ResB>, Error = Err> + 'static,
C: Clone + 'static,
Expand All @@ -118,20 +90,12 @@ where
Poll::Ready(Ok(()))
}

fn call(&mut self, req: Request<CompatReqBody<RequestExt<ReqB>>>) -> Self::Future {
let service = self.service.clone();
fn call(&mut self, req: Request<CompatReqBody<RequestExt<ReqB>, C>>) -> Self::Future {
let service = self.0.clone();
Box::pin(async move {
let (mut parts, body) = req.into_parts();

let ctx = parts
.extensions
.remove::<FakeClone<FakeSync<FakeSend<C>>>>()
.unwrap()
.into_inner()
.into_inner()
.into_inner();

let (ext, body) = body.into_inner().replace_body(());
let (parts, body) = req.into_parts();
let (body, ctx) = body.into_parts();
let (ext, body) = body.replace_body(());

let mut req = Request::from_parts(parts, ext);
let mut body = RefCell::new(body);
Expand All @@ -149,10 +113,7 @@ mod test {
use tower_http::set_status::SetStatusLayer;
use xitca_unsafe_collection::futures::NowOrPanic;

use crate::{
body::ResponseBody, http::StatusCode, http::WebRequest, middleware::eraser::TypeEraser, service::fn_service,
App,
};
use crate::{body::ResponseBody, http::StatusCode, http::WebRequest, service::fn_service, App};

use super::*;

Expand All @@ -166,8 +127,8 @@ mod test {
let res = App::new()
.with_state("996")
.at("/", fn_service(handler))
.enclosed(TowerHttpCompat::new(SetStatusLayer::new(StatusCode::OK)))
.enclosed(TowerHttpCompat::new(SetStatusLayer::new(StatusCode::NOT_FOUND)))
.enclosed(TypeEraser::response_body())
.finish()
.call(())
.now_or_panic()
Expand Down
55 changes: 24 additions & 31 deletions web/src/service/tower_http_compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use xitca_http::{
body::{none_body_hint, BodySize},
util::service::router::{RouterGen, RouterMapErr},
};
use xitca_unsafe_collection::fake::{FakeClone, FakeSend, FakeSync};
use xitca_unsafe_collection::fake::FakeSend;

use crate::{
bytes::Buf,
Expand All @@ -23,16 +23,14 @@ use crate::{

/// A middleware type that bridge `xitca-service` and `tower-service`.
/// Any `tower-http` type that impl [tower_service::Service] trait can be passed to it and used as xitca-web's service type.
pub struct TowerHttpCompat<S> {
service: S,
}
pub struct TowerHttpCompat<S>(S);

impl<S> TowerHttpCompat<S> {
pub const fn new(service: S) -> Self
where
S: Clone,
{
Self { service }
Self(service)
}
}

Expand All @@ -44,11 +42,8 @@ where
type Error = Infallible;

async fn call(&self, _: ()) -> Result<Self::Response, Self::Error> {
let service = self.service.clone();

Ok(TowerCompatService {
service: RefCell::new(service),
})
let service = self.0.clone();
Ok(TowerCompatService(RefCell::new(service)))
}
}

Expand All @@ -60,21 +55,17 @@ impl<S> RouterGen for TowerHttpCompat<S> {
}
}

pub struct TowerCompatService<S> {
service: RefCell<S>,
}
pub struct TowerCompatService<S>(RefCell<S>);

impl<S> TowerCompatService<S> {
pub fn new(service: S) -> Self {
Self {
service: RefCell::new(service),
}
pub const fn new(service: S) -> Self {
Self(RefCell::new(service))
}
}

impl<'r, C, ReqB, S, ResB> Service<WebContext<'r, C, ReqB>> for TowerCompatService<S>
where
S: tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>>>, Response = Response<ResB>>,
S: tower_service::Service<Request<CompatReqBody<RequestExt<ReqB>, C>>, Response = Response<ResB>>,
ResB: Body,
C: Clone + 'static,
ReqB: Default,
Expand All @@ -83,13 +74,10 @@ where
type Error = S::Error;

async fn call(&self, mut ctx: WebContext<'r, C, ReqB>) -> Result<Self::Response, Self::Error> {
let state = ctx.state().clone();
let (mut parts, ext) = ctx.take_request().into_parts();
parts
.extensions
.insert(FakeClone::new(FakeSync::new(FakeSend::new(state))));
let req = Request::from_parts(parts, CompatReqBody::new(ext));
let fut = tower_service::Service::call(&mut *self.service.borrow_mut(), req);
let (parts, ext) = ctx.take_request().into_parts();
let ctx = ctx.state().clone();
let req = Request::from_parts(parts, CompatReqBody::new(ext, ctx));
let fut = tower_service::Service::call(&mut *self.0.borrow_mut(), req);
fut.await.map(|res| res.map(CompatResBody::new))
}
}
Expand All @@ -101,25 +89,30 @@ impl<S> ReadyService for TowerCompatService<S> {
async fn ready(&self) -> Self::Ready {}
}

pub struct CompatReqBody<B> {
pub struct CompatReqBody<B, C> {
body: FakeSend<B>,
ctx: C,
}

impl<B> CompatReqBody<B> {
pub fn new(body: B) -> Self {
impl<B, C> CompatReqBody<B, C> {
#[inline]
pub fn new(body: B, ctx: C) -> Self {
Self {
body: FakeSend::new(body),
ctx,
}
}

pub fn into_inner(self) -> B {
self.body.into_inner()
#[inline]
pub fn into_parts(self) -> (B, C) {
(self.body.into_inner(), self.ctx)
}
}

impl<B, T, E> Body for CompatReqBody<B>
impl<B, C, T, E> Body for CompatReqBody<B, C>
where
B: Stream<Item = Result<T, E>> + Unpin,
C: Unpin,
T: Buf,
{
type Data = T;
Expand Down
Loading