From b5986c02852639dc8bc47d4362da872cb3375174 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 30 Jan 2024 17:06:54 +0800 Subject: [PATCH] add getter methods for RequestStream. (#913) --- client/Cargo.toml | 2 +- client/src/ws.rs | 48 ++++++++++++++++++----------------------- http-ws/CHANGES.md | 3 +++ http-ws/src/stream.rs | 10 +++++++++ test/tests/websocket.rs | 33 +++++++++++++++------------- 5 files changed, 53 insertions(+), 43 deletions(-) diff --git a/client/Cargo.toml b/client/Cargo.toml index d1ad8a27..49376f6b 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -67,7 +67,7 @@ serde_json = { version = "1", optional = true } # websocket support futures-sink = { version = "0.3.17", default-features = false, optional = true } -http-ws = { version = "0.2", default-features = false, optional = true } +http-ws = { version = "0.2", features = ["stream"], optional = true } [dev-dependencies] futures = "0.3" diff --git a/client/src/ws.rs b/client/src/ws.rs index db98be32..7a46a13d 100644 --- a/client/src/ws.rs +++ b/client/src/ws.rs @@ -13,10 +13,11 @@ use std::sync::Mutex; use futures_core::stream::Stream; use futures_sink::Sink; -use http_ws::Codec; +use http_ws::{Codec, RequestStream, WsError}; use xitca_http::bytes::{Buf, BytesMut}; use super::{ + body::BodyError, body::ResponseBody, error::Error, http::{Method, Version}, @@ -130,8 +131,7 @@ impl<'a> WebSocket<'a> { inner: Mutex::new(WebSocketInner { codec: Codec::new().client_mode(), send_buf: BytesMut::new(), - recv_buf: BytesMut::new(), - body, + recv_stream: RequestStream::with_codec(body, Codec::new().client_mode()), }), }) } @@ -147,7 +147,10 @@ impl<'a> WebSocket<'a> { /// /// By default max size is set to 64kB. pub fn max_size(mut self, size: usize) -> Self { - self.inner.get_mut().unwrap().codec.set_max_size(size); + let inner = self.inner.get_mut().unwrap(); + inner.codec = inner.codec.set_max_size(size); + let recv_codec = inner.recv_stream.codec_mut(); + *recv_codec = recv_codec.set_max_size(size); self } @@ -192,8 +195,7 @@ impl Stream for WebSocket<'_> { struct WebSocketInner<'b> { codec: Codec, send_buf: BytesMut, - recv_buf: BytesMut, - body: ResponseBody<'b>, + recv_stream: RequestStream, BodyError>, } impl Sink for WebSocketInner<'_> { @@ -217,9 +219,9 @@ impl Sink for WebSocketInner<'_> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let inner = self.get_mut(); - match inner.body { + match inner.recv_stream.inner_mut() { #[cfg(feature = "http1")] - ResponseBody::H1(ref mut body) => { + ResponseBody::H1(body) => { use std::io; use tokio::io::AsyncWrite; while !inner.send_buf.chunk().is_empty() { @@ -232,7 +234,7 @@ impl Sink for WebSocketInner<'_> { Pin::new(&mut **body.conn()).poll_flush(_cx).map_err(Into::into) } #[cfg(feature = "http2")] - ResponseBody::H2(ref mut body) => { + ResponseBody::H2(body) => { while !inner.send_buf.chunk().is_empty() { ready!(body.poll_send_buf(&mut inner.send_buf, _cx))?; } @@ -245,13 +247,13 @@ impl Sink for WebSocketInner<'_> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; - match self.get_mut().body { + match self.get_mut().recv_stream.inner_mut() { #[cfg(feature = "http1")] - ResponseBody::H1(ref mut body) => { + ResponseBody::H1(body) => { tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut **body.conn()), cx).map_err(Into::into) } #[cfg(feature = "http2")] - ResponseBody::H2(ref mut body) => { + ResponseBody::H2(body) => { body.send_data(xitca_http::bytes::Bytes::new(), true)?; Poll::Ready(Ok(())) } @@ -263,21 +265,13 @@ impl Sink for WebSocketInner<'_> { impl Stream for WebSocketInner<'_> { type Item = Result; + #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - loop { - if let Some(msg) = this.codec.decode(&mut this.recv_buf)? { - return Poll::Ready(Some(Ok(msg))); - } - - match ready!(Pin::new(&mut this.body).poll_next(cx)) { - Some(res) => { - let bytes = res?; - this.recv_buf.extend_from_slice(&bytes); - } - None => return Poll::Ready(None), - } - } + Pin::new(&mut self.get_mut().recv_stream) + .poll_next(cx) + .map_err(|e| match e { + WsError::Protocol(e) => Error::from(e), + WsError::Stream(e) => Error::Std(e), + }) } } diff --git a/http-ws/CHANGES.md b/http-ws/CHANGES.md index aed76790..6bec2ee2 100644 --- a/http-ws/CHANGES.md +++ b/http-ws/CHANGES.md @@ -1,4 +1,7 @@ # unreleased +## Add +- add `RequestStream::inner_mut` method for accessing inner stream type. +- add `RequestStream::codec_mut` method for accessing `Codec` type. # 0.2.0 ## Add diff --git a/http-ws/src/stream.rs b/http-ws/src/stream.rs index dbf927c3..056e20fa 100644 --- a/http-ws/src/stream.rs +++ b/http-ws/src/stream.rs @@ -50,6 +50,16 @@ where } } + #[inline] + pub fn inner_mut(&mut self) -> &mut S { + &mut self.stream + } + + #[inline] + pub fn codec_mut(&mut self) -> &mut Codec { + &mut self.codec + } + /// Make a [ResponseStream] from current DecodeStream. /// /// This API is to share the same codec for both decode and encode stream. diff --git a/test/tests/websocket.rs b/test/tests/websocket.rs index 4e8c0038..fc6a5492 100644 --- a/test/tests/websocket.rs +++ b/test/tests/websocket.rs @@ -43,27 +43,30 @@ async fn message_h2() -> Result<(), Error> { let server_url = format!("wss://{}/", handle.ip_port_string()); - let c = Client::new(); - let (mut tx, mut rx) = c.ws2(&server_url)?.send().await?.split(); + { + let c = Client::new(); + let (mut tx, mut rx) = c.ws2(&server_url)?.send().await?.split(); - for _ in 0..9 { - tx.send(Message::Text(Bytes::from("Hello,World!"))).await?; - } + for _ in 0..9 { + tx.send(Message::Text(Bytes::from("Hello,World!"))).await?; + } - for _ in 0..9 { - let msg = rx.next().await.unwrap()?; - assert_eq!(msg, Message::Text(Bytes::from("Hello,World!"))); - } + for _ in 0..9 { + let msg = rx.next().await.unwrap()?; + assert_eq!(msg, Message::Text(Bytes::from("Hello,World!"))); + } - tx.send(Message::Ping(Bytes::from("pingpong"))).await?; - let msg = rx.next().await.unwrap()?; - assert_eq!(msg, Message::Pong(Bytes::from("pingpong"))); + tx.send(Message::Ping(Bytes::from("pingpong"))).await?; + let msg = rx.next().await.unwrap()?; + assert_eq!(msg, Message::Pong(Bytes::from("pingpong"))); - tx.send(Message::Close(None)).await?; - let msg = rx.next().await.unwrap()?; - assert_eq!(msg, Message::Close(None)); + tx.send(Message::Close(None)).await?; + let msg = rx.next().await.unwrap()?; + assert_eq!(msg, Message::Close(None)); + } handle.try_handle()?.stop(true); + tokio::task::yield_now().await; handle.await.map_err(Into::into) }