Skip to content

Commit

Permalink
add getter methods for RequestStream. (#913)
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow authored Jan 30, 2024
1 parent 3cd099d commit b5986c0
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 43 deletions.
2 changes: 1 addition & 1 deletion client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ serde_json = { version = "1", optional = true }

# websocket support
futures-sink = { version = "0.3.17", default-features = false, optional = true }
http-ws = { version = "0.2", default-features = false, optional = true }
http-ws = { version = "0.2", features = ["stream"], optional = true }

[dev-dependencies]
futures = "0.3"
Expand Down
48 changes: 21 additions & 27 deletions client/src/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use std::sync::Mutex;

use futures_core::stream::Stream;
use futures_sink::Sink;
use http_ws::Codec;
use http_ws::{Codec, RequestStream, WsError};
use xitca_http::bytes::{Buf, BytesMut};

use super::{
body::BodyError,
body::ResponseBody,
error::Error,
http::{Method, Version},
Expand Down Expand Up @@ -130,8 +131,7 @@ impl<'a> WebSocket<'a> {
inner: Mutex::new(WebSocketInner {
codec: Codec::new().client_mode(),
send_buf: BytesMut::new(),
recv_buf: BytesMut::new(),
body,
recv_stream: RequestStream::with_codec(body, Codec::new().client_mode()),
}),
})
}
Expand All @@ -147,7 +147,10 @@ impl<'a> WebSocket<'a> {
///
/// By default max size is set to 64kB.
pub fn max_size(mut self, size: usize) -> Self {
self.inner.get_mut().unwrap().codec.set_max_size(size);
let inner = self.inner.get_mut().unwrap();
inner.codec = inner.codec.set_max_size(size);
let recv_codec = inner.recv_stream.codec_mut();
*recv_codec = recv_codec.set_max_size(size);
self
}

Expand Down Expand Up @@ -192,8 +195,7 @@ impl Stream for WebSocket<'_> {
struct WebSocketInner<'b> {
codec: Codec,
send_buf: BytesMut,
recv_buf: BytesMut,
body: ResponseBody<'b>,
recv_stream: RequestStream<ResponseBody<'b>, BodyError>,
}

impl Sink<Message> for WebSocketInner<'_> {
Expand All @@ -217,9 +219,9 @@ impl Sink<Message> for WebSocketInner<'_> {
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let inner = self.get_mut();

match inner.body {
match inner.recv_stream.inner_mut() {
#[cfg(feature = "http1")]
ResponseBody::H1(ref mut body) => {
ResponseBody::H1(body) => {
use std::io;
use tokio::io::AsyncWrite;
while !inner.send_buf.chunk().is_empty() {
Expand All @@ -232,7 +234,7 @@ impl Sink<Message> for WebSocketInner<'_> {
Pin::new(&mut **body.conn()).poll_flush(_cx).map_err(Into::into)
}
#[cfg(feature = "http2")]
ResponseBody::H2(ref mut body) => {
ResponseBody::H2(body) => {
while !inner.send_buf.chunk().is_empty() {
ready!(body.poll_send_buf(&mut inner.send_buf, _cx))?;
}
Expand All @@ -245,13 +247,13 @@ impl Sink<Message> for WebSocketInner<'_> {

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?;
match self.get_mut().body {
match self.get_mut().recv_stream.inner_mut() {
#[cfg(feature = "http1")]
ResponseBody::H1(ref mut body) => {
ResponseBody::H1(body) => {
tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut **body.conn()), cx).map_err(Into::into)
}
#[cfg(feature = "http2")]
ResponseBody::H2(ref mut body) => {
ResponseBody::H2(body) => {
body.send_data(xitca_http::bytes::Bytes::new(), true)?;
Poll::Ready(Ok(()))
}
Expand All @@ -263,21 +265,13 @@ impl Sink<Message> for WebSocketInner<'_> {
impl Stream for WebSocketInner<'_> {
type Item = Result<Message, Error>;

#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

loop {
if let Some(msg) = this.codec.decode(&mut this.recv_buf)? {
return Poll::Ready(Some(Ok(msg)));
}

match ready!(Pin::new(&mut this.body).poll_next(cx)) {
Some(res) => {
let bytes = res?;
this.recv_buf.extend_from_slice(&bytes);
}
None => return Poll::Ready(None),
}
}
Pin::new(&mut self.get_mut().recv_stream)
.poll_next(cx)
.map_err(|e| match e {
WsError::Protocol(e) => Error::from(e),
WsError::Stream(e) => Error::Std(e),
})
}
}
3 changes: 3 additions & 0 deletions http-ws/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# unreleased
## Add
- add `RequestStream::inner_mut` method for accessing inner stream type.
- add `RequestStream::codec_mut` method for accessing `Codec` type.

# 0.2.0
## Add
Expand Down
10 changes: 10 additions & 0 deletions http-ws/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ where
}
}

#[inline]
pub fn inner_mut(&mut self) -> &mut S {
&mut self.stream
}

#[inline]
pub fn codec_mut(&mut self) -> &mut Codec {
&mut self.codec
}

/// Make a [ResponseStream] from current DecodeStream.
///
/// This API is to share the same codec for both decode and encode stream.
Expand Down
33 changes: 18 additions & 15 deletions test/tests/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,30 @@ async fn message_h2() -> Result<(), Error> {

let server_url = format!("wss://{}/", handle.ip_port_string());

let c = Client::new();
let (mut tx, mut rx) = c.ws2(&server_url)?.send().await?.split();
{
let c = Client::new();
let (mut tx, mut rx) = c.ws2(&server_url)?.send().await?.split();

for _ in 0..9 {
tx.send(Message::Text(Bytes::from("Hello,World!"))).await?;
}
for _ in 0..9 {
tx.send(Message::Text(Bytes::from("Hello,World!"))).await?;
}

for _ in 0..9 {
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Text(Bytes::from("Hello,World!")));
}
for _ in 0..9 {
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Text(Bytes::from("Hello,World!")));
}

tx.send(Message::Ping(Bytes::from("pingpong"))).await?;
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Pong(Bytes::from("pingpong")));
tx.send(Message::Ping(Bytes::from("pingpong"))).await?;
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Pong(Bytes::from("pingpong")));

tx.send(Message::Close(None)).await?;
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Close(None));
tx.send(Message::Close(None)).await?;
let msg = rx.next().await.unwrap()?;
assert_eq!(msg, Message::Close(None));
}

handle.try_handle()?.stop(true);
tokio::task::yield_now().await;
handle.await.map_err(Into::into)
}

Expand Down

0 comments on commit b5986c0

Please sign in to comment.