Skip to content

feat(io): reimplement vectored extensions #364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compio-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub mod compat;
mod read;
mod split;
pub mod util;
mod vectored;
mod write;

pub(crate) type IoResult<T> = std::io::Result<T>;
Expand Down
93 changes: 46 additions & 47 deletions compio-io/src/read/ext.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#[cfg(feature = "allocator_api")]
use std::alloc::Allocator;

use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, t_alloc};
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, SetBufInit, t_alloc};

use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take, vectored::VectoredWrap};

/// Shared code for read a scalar value from the underlying reader.
macro_rules! read_scalar {
Expand Down Expand Up @@ -36,63 +36,35 @@ macro_rules! read_scalar {

/// Shared code for loop reading until reaching a certain length.
macro_rules! loop_read_exact {
($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
let mut $tracker = 0;
($buf:ident, $len:expr, $tracker:ident, $read_expr:expr, $update_expr:expr, $buf_expr:expr) => {
let mut $tracker = 0usize;
let len = $len;

while $tracker < len {
match $read_expr.await.into_inner() {
BufResult(Ok(0), buf) => {
let BufResult(res, buf) = $read_expr;
$buf = buf;
match res {
Ok(0) => {
return BufResult(
Err(::std::io::Error::new(
::std::io::ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
)),
buf,
$buf_expr,
);
}
BufResult(Ok(n), buf) => {
$tracker += n;
$buf = buf;
}
BufResult(Err(ref e), buf) if e.kind() == ::std::io::ErrorKind::Interrupted => {
$buf = buf;
Ok(n) => {
$tracker += n as usize;
$update_expr;
}
BufResult(Err(e), buf) => return BufResult(Err(e), buf),
Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
Err(e) => return BufResult(Err(e), $buf_expr),
}
}
return BufResult(Ok(()), $buf)
return BufResult(Ok(()), $buf_expr)
};
}

macro_rules! loop_read_vectored {
($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
use ::compio_buf::OwnedIterator;

let mut $iter = match $buf.owned_iter() {
Ok(buf) => buf,
Err(buf) => return BufResult(Ok(()), buf),
};
let mut $tracker: $tracker_ty = 0;

loop {
let len = $iter.buf_capacity();
if len > 0 {
match $read_expr.await {
BufResult(Ok(()), ret) => {
$iter = ret;
$tracker += len as $tracker_ty;
}
BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
};
}

match $iter.next() {
Ok(next) => $iter = next,
Err(buf) => return BufResult(Ok(()), buf),
}
}
}};
($buf:ident, $iter:ident, $read_expr:expr) => {{
use ::compio_buf::OwnedIterator;

Expand Down Expand Up @@ -158,7 +130,14 @@ pub trait AsyncReadExt: AsyncRead {

/// Read the exact number of bytes required to fill the buf.
async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
loop_read_exact!(
buf,
buf.buf_capacity(),
read,
self.read(buf.slice(read..)).await.into_inner(),
{},
buf
);
}

/// Read all bytes until underlying reader reaches `EOF`.
Expand All @@ -171,7 +150,15 @@ pub trait AsyncReadExt: AsyncRead {

/// Read the exact number of bytes required to fill the vectored buf.
async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
let mut buf = VectoredWrap::new(buf);
loop_read_exact!(
buf,
buf.capacity(),
read,
self.read_vectored(buf).await,
unsafe { buf.set_buf_init(read) },
buf.into_inner()
);
}

/// Creates an adaptor which reads at most `limit` bytes from it.
Expand Down Expand Up @@ -234,7 +221,11 @@ pub trait AsyncReadAtExt: AsyncReadAt {
buf,
buf.buf_capacity(),
read,
loop self.read_at(buf.slice(read..), pos + read as u64)
self.read_at(buf.slice(read..), pos + read as u64)
.await
.into_inner(),
{},
buf
);
}

Expand Down Expand Up @@ -262,7 +253,15 @@ pub trait AsyncReadAtExt: AsyncReadAt {
buf: T,
pos: u64,
) -> BufResult<(), T> {
loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
let mut buf = VectoredWrap::new(buf);
loop_read_exact!(
buf,
buf.capacity(),
read,
self.read_vectored_at(buf, pos + read as u64).await,
unsafe { buf.set_buf_init(read) },
buf.into_inner()
);
}
}

Expand Down
153 changes: 153 additions & 0 deletions compio-io/src/vectored.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::pin::Pin;

use compio_buf::{
Indexable, IndexableMut, IndexedIter, IntoInner, IoBuf, IoBufMut, IoVectoredBuf,
IoVectoredBufMut, MaybeOwned, MaybeOwnedMut, SetBufInit,
};

pub struct VectoredWrap<T> {
buffers: Pin<Box<T>>,
wraps: Vec<BufWrap>,
vec_off: usize,
}

impl<T: IoVectoredBuf> VectoredWrap<T> {
pub fn new(buffers: T) -> Self {
let buffers = Box::pin(buffers);
let wraps = buffers.iter_buf().map(|buf| BufWrap::new(&*buf)).collect();
Self {
buffers,
wraps,
vec_off: 0,
}
}

pub fn len(&self) -> usize {
self.wraps.iter().map(|buf| buf.len).sum()
}

pub fn capacity(&self) -> usize {
self.wraps.iter().map(|buf| buf.capacity).sum()
}
}

impl<T: IoVectoredBuf + 'static> IoVectoredBuf for VectoredWrap<T> {
type Buf = BufWrap;
type OwnedIter = IndexedIter<Self>;

fn iter_buf(&self) -> impl Iterator<Item = MaybeOwned<'_, Self::Buf>> {
self.wraps
.iter()
.skip(self.vec_off)
.map(MaybeOwned::Borrowed)
}

fn owned_iter(self) -> Result<Self::OwnedIter, Self>
where
Self: Sized,
{
IndexedIter::new(self)
}
}

impl<T: IoVectoredBufMut + 'static> IoVectoredBufMut for VectoredWrap<T> {
fn iter_buf_mut(&mut self) -> impl Iterator<Item = MaybeOwnedMut<'_, Self::Buf>> {
self.wraps
.iter_mut()
.skip(self.vec_off)
.map(MaybeOwnedMut::Borrowed)
}
}

impl<T: SetBufInit> SetBufInit for VectoredWrap<T> {
unsafe fn set_buf_init(&mut self, mut len: usize) {
self.buffers.as_mut().get_unchecked_mut().set_buf_init(len);
self.vec_off = 0;
for buf in self.wraps.iter_mut().skip(self.vec_off) {
let capacity = (*buf).buf_capacity();
let buf_new_len = len.min(capacity);
buf.set_buf_init(buf_new_len);
*buf = buf.offset(buf_new_len);
if len >= capacity {
len -= capacity;
} else {
break;
}
self.vec_off += 1;
}
}
}

impl<T> Indexable for VectoredWrap<T> {
type Output = BufWrap;

fn index(&self, n: usize) -> Option<&Self::Output> {
self.wraps.get(n + self.vec_off)
}
}

impl<T> IndexableMut for VectoredWrap<T> {
fn index_mut(&mut self, n: usize) -> Option<&mut Self::Output> {
self.wraps.get_mut(n + self.vec_off)
}
}

impl<T> IntoInner for VectoredWrap<T> {
type Inner = T;

fn into_inner(self) -> Self::Inner {
// Safety: no pointers still maintaining
*unsafe { Pin::into_inner_unchecked(self.buffers) }
}
}

pub struct BufWrap {
ptr: *mut u8,
len: usize,
capacity: usize,
}

impl BufWrap {
fn new<T: IoBuf>(buf: &T) -> Self {
Self {
ptr: buf.as_buf_ptr().cast_mut(),
len: buf.buf_len(),
capacity: buf.buf_capacity(),
}
}

fn offset(&self, off: usize) -> Self {
Self {
ptr: unsafe { self.ptr.add(off) },
len: self.len.saturating_sub(off),
capacity: self.capacity.saturating_sub(off),
}
}
}

unsafe impl IoBuf for BufWrap {
fn as_buf_ptr(&self) -> *const u8 {
self.ptr.cast_const()
}

fn buf_len(&self) -> usize {
self.len
}

fn buf_capacity(&self) -> usize {
self.capacity
}
}

unsafe impl IoBufMut for BufWrap {
fn as_buf_mut_ptr(&mut self) -> *mut u8 {
self.ptr
}
}

impl SetBufInit for BufWrap {
unsafe fn set_buf_init(&mut self, len: usize) {
debug_assert!(len <= self.capacity, "{} > {}", len, self.capacity);
self.len = self.len.max(len);
}
}
Loading
Loading