From b845b7b9c4a6acdd09592bfd78b10361739f48b2 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:45:23 +0800 Subject: [PATCH] internal_send function --- src/async_session.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/async_session.rs b/src/async_session.rs index af9047e..ff01755 100644 --- a/src/async_session.rs +++ b/src/async_session.rs @@ -85,6 +85,10 @@ impl AsyncSession { } pub async fn send(&self, buf: &[u8]) -> std::io::Result { + self.internal_send(buf) + } + + fn internal_send(&self, buf: &[u8]) -> std::io::Result { let packet = self.session.allocate_send_packet(buf.len() as _)?; packet.bytes.copy_from_slice(buf); self.session.send_packet(packet); @@ -94,11 +98,15 @@ impl AsyncSession { impl AsyncRead for AsyncSession { fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + use std::io::{Error, ErrorKind::Other}; loop { match &mut self.read_state { ReadState::Idle => match self.session.try_receive() { Ok(Some(packet)) => { - let size = packet.bytes.len().min(buf.len()); + let size = packet.bytes.len(); + if buf.len() < size { + return Poll::Ready(Err(Error::new(Other, "Buffer too small"))); + } buf[..size].copy_from_slice(&packet.bytes[..size]); return Poll::Ready(Ok(size)); } @@ -122,8 +130,7 @@ impl AsyncRead for AsyncSession { Ok(guard) => guard, Err(e) => { self.read_state = ReadState::Waiting(Some(task)); - use std::io::{Error, ErrorKind}; - return Poll::Ready(Err(Error::new(ErrorKind::Other, format!("Lock task failed: {}", e)))); + return Poll::Ready(Err(Error::new(Other, format!("Lock task failed: {}", e)))); } }; self.read_state = match Pin::new(&mut *task_guard).poll(cx) { @@ -143,10 +150,7 @@ impl AsyncRead for AsyncSession { impl AsyncWrite for AsyncSession { fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - let packet = self.session.allocate_send_packet(buf.len() as _)?; - packet.bytes.copy_from_slice(buf); - self.session.send_packet(packet); - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(self.internal_send(buf)?)) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> {