From 7a1415574974885f6b36c7c3c3d98dc686faee92 Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Sat, 30 Dec 2023 19:10:07 +0900 Subject: [PATCH] feat(oio::read): implement the blocking buffer reader (#3860) * feat: implement blocking buffer reader * test: add tests for blocking buffer reader --- core/src/raw/oio/read/buffer_reader.rs | 361 +++++++++++++++++++++---- 1 file changed, 314 insertions(+), 47 deletions(-) diff --git a/core/src/raw/oio/read/buffer_reader.rs b/core/src/raw/oio/read/buffer_reader.rs index 24c2de17b4a..39da987153a 100644 --- a/core/src/raw/oio/read/buffer_reader.rs +++ b/core/src/raw/oio/read/buffer_reader.rs @@ -29,6 +29,8 @@ use std::task::Poll; use crate::raw::*; use crate::*; +use super::BlockingRead; + /// [BufferReader] allows the underlying reader to fetch data at the buffer's size /// and is used to amortize the IO's overhead. pub struct BufferReader { @@ -65,6 +67,31 @@ impl BufferReader { fn capacity(&self) -> usize { self.buf.capacity() } + + fn consume(&mut self, amt: usize) { + let new_pos = min(self.pos + amt, self.filled); + let amt = new_pos - self.pos; + + self.pos = new_pos; + self.cur += amt as u64; + } + + fn seek_relative(&mut self, offset: i64) -> Option { + let pos = self.pos as u64; + + if let (Some(new_pos), Some(new_cur)) = ( + pos.checked_add_signed(offset), + self.cur.checked_add_signed(offset), + ) { + if new_pos <= self.filled as u64 { + self.cur = new_cur; + self.pos = new_pos as usize; + return Some(self.cur); + } + } + + None + } } impl BufferReader @@ -95,31 +122,6 @@ where Poll::Ready(Ok(&self.buf[self.pos..self.filled])) } - fn consume(&mut self, amt: usize) { - let new_pos = min(self.pos + amt, self.filled); - let amt = new_pos - self.pos; - - self.pos = new_pos; - self.cur += amt as u64; - } - - fn seek_relative(&mut self, offset: i64) -> Option { - let pos = self.pos as u64; - - if let (Some(new_pos), Some(new_cur)) = ( - pos.checked_add_signed(offset), - self.cur.checked_add_signed(offset), - ) { - if new_pos <= self.filled as u64 { - self.cur = new_cur; - self.pos = new_pos as usize; - return Some(self.cur); - } - } - - None - } - fn poll_inner_seek(&mut self, cx: &mut Context<'_>, pos: SeekFrom) -> Poll> { let cur = ready!(self.r.poll_seek(cx, pos))?; self.discard_buffer(); @@ -192,23 +194,120 @@ where } } +impl BufferReader +where + R: BlockingRead, +{ + fn fill_buf(&mut self) -> Result<&[u8]> { + // If we've reached the end of our internal buffer then we need to fetch + // some more data from the underlying reader. + // Branch using `>=` instead of the more correct `==` + // to tell the compiler that the pos..cap slice is always valid. + if self.pos >= self.filled { + debug_assert!(self.pos == self.filled); + + let cap = self.capacity(); + self.buf.clear(); + let dst = self.buf.spare_capacity_mut(); + let mut buf = ReadBuf::uninit(dst); + unsafe { buf.assume_init(cap) }; + + let n = self.r.read(buf.initialized_mut())?; + unsafe { self.buf.set_len(n) } + + self.pos = 0; + self.filled = n; + } + + Ok(&self.buf[self.pos..self.filled]) + } + + fn inner_seek(&mut self, pos: SeekFrom) -> Result { + let cur = self.r.seek(pos)?; + self.discard_buffer(); + self.cur = cur; + + Ok(cur) + } +} + +impl BlockingRead for BufferReader +where + R: BlockingRead, +{ + fn read(&mut self, mut dst: &mut [u8]) -> Result { + // Sanity check for normal cases. + if dst.is_empty() { + return Ok(0); + } + + // If we don't have any buffered data and we're doing a massive read + // (larger than our internal buffer), bypass our internal buffer + // entirely. + if self.pos == self.filled && dst.len() >= self.capacity() { + let res = self.r.read(dst); + self.discard_buffer(); + return res; + } + + let rem = self.fill_buf()?; + let amt = min(rem.len(), dst.len()); + dst.put(&rem[..amt]); + self.consume(amt); + Ok(amt) + } + + fn seek(&mut self, pos: SeekFrom) -> Result { + match pos { + SeekFrom::Start(new_pos) => { + // TODO(weny): Check the overflowing. + let Some(offset) = (new_pos as i64).checked_sub(self.cur as i64) else { + return self.inner_seek(pos); + }; + + match self.seek_relative(offset) { + Some(cur) => Ok(cur), + None => self.inner_seek(pos), + } + } + SeekFrom::Current(offset) => match self.seek_relative(offset) { + Some(cur) => Ok(cur), + None => self.inner_seek(pos), + }, + SeekFrom::End(_) => self.inner_seek(pos), + } + } + + fn next(&mut self) -> Option> { + match self.fill_buf() { + Ok(bytes) => { + if bytes.is_empty() { + return None; + } + + let bytes = Bytes::copy_from_slice(bytes); + self.consume(bytes.len()); + Some(Ok(bytes)) + } + Err(err) => Some(Err(err)), + } + } +} + #[cfg(test)] mod tests { use std::io::SeekFrom; - use std::pin::Pin; use std::sync::Arc; + use crate::raw::oio::RangeReader; use async_trait::async_trait; use bytes::Bytes; - use futures::AsyncRead; use futures::AsyncReadExt; use futures::AsyncSeekExt; use rand::prelude::*; use sha2::Digest; use sha2::Sha256; - use crate::raw::oio::RangeReader; - use super::*; // Generate bytes between [4MiB, 16MiB) @@ -236,7 +335,7 @@ mod tests { #[async_trait] impl Accessor for MockReadService { type Reader = MockReader; - type BlockingReader = (); + type BlockingReader = MockReader; type Writer = (); type BlockingWriter = (); type Lister = (); @@ -258,22 +357,30 @@ mod tests { Ok(( RpRead::new(), MockReader { - inner: futures::io::Cursor::new(bs.into()), + inner: oio::Cursor::from(bs), + }, + )) + } + + fn blocking_read(&self, _: &str, args: OpRead) -> Result<(RpRead, Self::BlockingReader)> { + let bs = args.range().apply_on_bytes(self.data.clone()); + + Ok(( + RpRead::new(), + MockReader { + inner: oio::Cursor::from(bs), }, )) } } - #[derive(Debug, Clone, Default)] struct MockReader { - inner: futures::io::Cursor>, + inner: oio::Cursor, } impl oio::Read for MockReader { fn poll_read(&mut self, cx: &mut Context, buf: &mut [u8]) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf).map_err(|err| { - Error::new(ErrorKind::Unexpected, "read data from mock").set_source(err) - }) + self.inner.poll_read(cx, buf) } fn poll_seek(&mut self, cx: &mut Context<'_>, pos: SeekFrom) -> Poll> { @@ -286,17 +393,24 @@ mod tests { } fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { - let mut bs = vec![0; 4 * 1024]; - let n = ready!(Pin::new(&mut self.inner) - .poll_read(cx, &mut bs) - .map_err( - |err| Error::new(ErrorKind::Unexpected, "read data from mock").set_source(err) - )?); - if n == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from(bs[..n].to_vec())))) - } + self.inner.poll_next(cx) + } + } + + impl BlockingRead for MockReader { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.inner.read(buf) + } + + fn seek(&mut self, _pos: SeekFrom) -> Result { + Err(Error::new( + ErrorKind::Unsupported, + "output reader doesn't support seeking", + )) + } + + fn next(&mut self) -> Option> { + self.inner.next() } } @@ -463,4 +577,157 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_blocking_read_from_buf() -> anyhow::Result<()> { + let bs = Bytes::copy_from_slice(&b"Hello, World!"[..]); + let r = Box::new(oio::Cursor::from(bs.clone())) as oio::BlockingReader; + let buf_cap = 10; + let mut r = Box::new(BufferReader::new(r, buf_cap)) as oio::BlockingReader; + + let mut dst = [0u8; 5]; + let nread = r.read(&mut dst)?; + assert_eq!(nread, dst.len()); + assert_eq!(&dst, b"Hello"); + + let mut dst = [0u8; 5]; + let nread = r.read(&mut dst)?; + assert_eq!(nread, dst.len()); + assert_eq!(&dst, b", Wor"); + + let mut dst = [0u8; 3]; + let nread = r.read(&mut dst)?; + assert_eq!(nread, dst.len()); + assert_eq!(&dst, b"ld!"); + + Ok(()) + } + + #[tokio::test] + async fn test_blocking_seek() -> anyhow::Result<()> { + let bs = Bytes::copy_from_slice(&b"Hello, World!"[..]); + let r = Box::new(oio::Cursor::from(bs.clone())) as oio::BlockingReader; + let buf_cap = 10; + let mut r = Box::new(BufferReader::new(r, buf_cap)) as oio::BlockingReader; + + // The underlying reader buffers the b"Hello, Wor". + let mut dst = [0u8; 5]; + let nread = r.read(&mut dst)?; + assert_eq!(nread, dst.len()); + assert_eq!(&dst, b"Hello"); + + let pos = r.seek(SeekFrom::Start(7))?; + assert_eq!(pos, 7); + let mut dst = [0u8; 5]; + let nread = r.read(&mut dst)?; + assert_eq!(&dst[..nread], &bs[7..10]); + assert_eq!(nread, 3); + + // Should perform a relative seek. + let pos = r.seek(SeekFrom::Start(0))?; + assert_eq!(pos, 0); + let mut dst = [0u8; 9]; + let nread = r.read(&mut dst)?; + assert_eq!(&dst[..nread], &bs[0..9]); + assert_eq!(nread, 9); + + // Should perform a non-relative seek. + let pos = r.seek(SeekFrom::Start(11))?; + assert_eq!(pos, 11); + let mut dst = [0u8; 9]; + let nread = r.read(&mut dst)?; + assert_eq!(&dst[..nread], &bs[11..13]); + assert_eq!(nread, 2); + + Ok(()) + } + + #[tokio::test] + async fn test_blocking_read_all() -> anyhow::Result<()> { + let (bs, _) = gen_bytes(); + let r = Box::new(oio::Cursor::from(bs.clone())) as oio::BlockingReader; + let mut r = Box::new(BufferReader::new(r, 4096 * 1024)) as oio::BlockingReader; + + let mut buf = Vec::new(); + r.read_to_end(&mut buf)?; + assert_eq!(bs.len(), buf.len(), "read size"); + assert_eq!( + format!("{:x}", Sha256::digest(&bs)), + format!("{:x}", Sha256::digest(&buf)), + "read content" + ); + + let n = r.seek(SeekFrom::Start(0))?; + assert_eq!(n, 0, "seek position must be 0"); + + let mut buf = Vec::new(); + r.read_to_end(&mut buf)?; + assert_eq!(bs.len(), buf.len(), "read twice size"); + assert_eq!( + format!("{:x}", Sha256::digest(&bs)), + format!("{:x}", Sha256::digest(&buf)), + "read twice content" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_blocking_read_part() -> anyhow::Result<()> { + use std::io::Read; + + let (bs, _) = gen_bytes(); + let acc = Arc::new(MockReadService::new(bs.clone())); + let r = Box::new(RangeReader::new( + acc, + "x", + OpRead::default().with_range(BytesRange::from(4096..4096 + 4096)), + )) as oio::BlockingReader; + let mut r = Box::new(BufferReader::new(r, 4096 * 1024)) as oio::BlockingReader; + + let mut buf = Vec::new(); + BlockingRead::read_to_end(&mut r, &mut buf)?; + assert_eq!(4096, buf.len(), "read size"); + assert_eq!( + format!("{:x}", Sha256::digest(&bs[4096..4096 + 4096])), + format!("{:x}", Sha256::digest(&buf)), + "read content" + ); + + let n = r.seek(SeekFrom::Start(0))?; + assert_eq!(n, 0, "seek position must be 0"); + + let mut buf = Vec::new(); + BlockingRead::read_to_end(&mut r, &mut buf)?; + assert_eq!(4096, buf.len(), "read twice size"); + assert_eq!( + format!("{:x}", Sha256::digest(&bs[4096..4096 + 4096])), + format!("{:x}", Sha256::digest(&buf)), + "read twice content" + ); + + let n = r.seek(SeekFrom::Start(1024))?; + assert_eq!(1024, n, "seek to 1024"); + + let mut buf = vec![0; 1024]; + r.read_exact(&mut buf)?; + assert_eq!( + format!("{:x}", Sha256::digest(&bs[4096 + 1024..4096 + 2048])), + format!("{:x}", Sha256::digest(&buf)), + "read after seek 1024" + ); + + let n = r.seek(SeekFrom::Current(1024))?; + assert_eq!(3072, n, "seek to 3072"); + + let mut buf = vec![0; 1024]; + r.read_exact(&mut buf)?; + assert_eq!( + format!("{:x}", Sha256::digest(&bs[4096 + 3072..4096 + 3072 + 1024])), + format!("{:x}", Sha256::digest(&buf)), + "read after seek to 3072" + ); + + Ok(()) + } }