diff --git a/.typos.toml b/.typos.toml
index 259a4133e51..5b2bca99b5a 100644
--- a/.typos.toml
+++ b/.typos.toml
@@ -20,5 +20,6 @@
 "Dum" = "Dum"
 "ba" = "ba"
 "Hel" = "Hel"
+"hellow" = "hellow" # Showed up in examples.
 "thw" = "thw"
diff --git a/core/src/raw/oio/buf/chunked_bytes.rs b/core/src/raw/oio/buf/chunked_bytes.rs
new file mode 100644
index 00000000000..82447dc62b4
--- /dev/null
+++ b/core/src/raw/oio/buf/chunked_bytes.rs
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use bytes::{Buf, Bytes, BytesMut};
+use std::cmp::min;
+use std::collections::VecDeque;
+use std::io::IoSlice;
+use std::task::{Context, Poll};
+
+use crate::raw::*;
+use crate::*;
+
+// TODO: 64KiB is picked based on experience; it should be configurable.
+const DEFAULT_CHUNK_SIZE: usize = 64 * 1024;
+
+/// ChunkedBytes represents non-contiguous bytes in memory.
+#[derive(Clone)]
+pub struct ChunkedBytes {
+    chunk_size: usize,
+
+    frozen: VecDeque<Bytes>,
+    active: BytesMut,
+    size: usize,
+}
+
+impl Default for ChunkedBytes {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ChunkedBytes {
+    /// Create a new chunked bytes.
+    pub fn new() -> Self {
+        Self {
+            frozen: VecDeque::new(),
+            active: BytesMut::new(),
+            size: 0,
+
+            chunk_size: DEFAULT_CHUNK_SIZE,
+        }
+    }
+
+    /// Create a new chunked bytes with the given chunk size.
+    pub fn with_chunk_size(chunk_size: usize) -> Self {
+        Self {
+            frozen: VecDeque::new(),
+            active: BytesMut::new(),
+            size: 0,
+
+            chunk_size,
+        }
+    }
+
+    /// Build a chunked bytes from a vector of bytes.
+    ///
+    /// This function is guaranteed to run in O(1) time and to not re-allocate the Vec's buffer
+    /// or allocate any additional memory.
+    ///
+    /// Reference:
+    pub fn from_vec(bs: Vec<Bytes>) -> Self {
+        Self {
+            size: bs.iter().map(|v| v.len()).sum(),
+            frozen: bs.into(),
+            active: BytesMut::new(),
+
+            chunk_size: DEFAULT_CHUNK_SIZE,
+        }
+    }
+
+    /// Returns `true` if current chunked bytes is empty.
+    pub fn is_empty(&self) -> bool {
+        self.size == 0
+    }
+
+    /// Return current bytes size of chunked bytes.
+    pub fn len(&self) -> usize {
+        self.size
+    }
+
+    /// Clear the entire chunked bytes.
+    pub fn clear(&mut self) {
+        self.size = 0;
+        self.frozen.clear();
+        self.active.clear();
+    }
+
+    /// Push a new bytes into ChunkedBytes.
+    pub fn push(&mut self, mut bs: Bytes) {
+        self.size += bs.len();
+
+        // Optimization: if active is empty, we can push to frozen directly if possible.
+        if self.active.is_empty() {
+            let aligned_size = bs.len() - bs.len() % self.chunk_size;
+            if aligned_size > 0 {
+                self.frozen.push_back(bs.split_to(aligned_size));
+            }
+            if !bs.is_empty() {
+                self.active.extend_from_slice(&bs);
+            }
+            return;
+        }
+
+        // Try to fill bytes into active first.
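+        // Note: at this point active is non-empty, so we top it up to
+        // chunk_size first; a full active buffer is frozen, and any
+        // chunk-aligned remainder of `bs` moves into frozen without copying.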
+        let remaining = self.chunk_size.saturating_sub(self.active.len());
+        if remaining > 0 {
+            let len = min(remaining, bs.len());
+            self.active.extend_from_slice(&bs.split_to(len));
+        }
+
+        // If active is full, freeze it and push it into frozen.
+        if self.active.len() == self.chunk_size {
+            self.frozen.push_back(self.active.split().freeze());
+        }
+
+        // Split remaining bytes into chunks.
+        let aligned_size = bs.len() - bs.len() % self.chunk_size;
+        if aligned_size > 0 {
+            self.frozen.push_back(bs.split_to(aligned_size));
+        }
+
+        // Append to active if there are remaining bytes.
+        if !bs.is_empty() {
+            self.active.extend_from_slice(&bs);
+        }
+    }
+
+    /// Push a new &[u8] into ChunkedBytes.
+    pub fn extend_from_slice(&mut self, bs: &[u8]) {
+        self.size += bs.len();
+
+        let mut remaining = bs;
+
+        while !remaining.is_empty() {
+            let available = self.chunk_size.saturating_sub(self.active.len());
+
+            // available == 0 means self.active.len() >= self.chunk_size.
+            if available == 0 {
+                self.frozen.push_back(self.active.split().freeze());
+                self.active.reserve(self.chunk_size);
+                continue;
+            }
+
+            let size = min(remaining.len(), available);
+            self.active.extend_from_slice(&remaining[0..size]);
+
+            remaining = &remaining[size..];
+        }
+    }
+
+    /// Pull data from [`oio::WriteBuf`] into ChunkedBytes.
+    pub fn extend_from_write_buf(&mut self, size: usize, buf: &dyn oio::WriteBuf) -> usize {
+        let to_write = min(buf.chunk().len(), size);
+
+        if buf.is_bytes_optimized(to_write) && to_write > self.chunk_size {
+            // If the chunk is optimized, we can just push it directly.
+            self.push(buf.bytes(to_write));
+        } else {
+            // Otherwise, we should copy it into the buffer.
+            self.extend_from_slice(&buf.chunk()[..to_write]);
+        }
+
+        to_write
+    }
+}
+
+impl oio::WriteBuf for ChunkedBytes {
+    fn remaining(&self) -> usize {
+        self.size
+    }
+
+    fn advance(&mut self, mut cnt: usize) {
+        debug_assert!(
+            cnt <= self.size,
+            "cnt size {} is larger than bytes size {}",
+            cnt,
+            self.size
+        );
+
+        self.size -= cnt;
+
+        while cnt > 0 {
+            if let Some(front) = self.frozen.front_mut() {
+                if front.len() <= cnt {
+                    cnt -= front.len();
+                    self.frozen.pop_front(); // Remove the entire chunk.
+                } else {
+                    front.advance(cnt); // Split and keep the remaining part.
+                    break;
+                }
+            } else {
+                // Here, cnt must be <= self.active.len() due to the checks above.
+                self.active.advance(cnt); // Remove cnt bytes from the active buffer.
+                break;
+            }
+        }
+    }
+
+    fn chunk(&self) -> &[u8] {
+        match self.frozen.front() {
+            Some(v) => v,
+            None => &self.active,
+        }
+    }
+
+    fn vectored_chunk(&self) -> Vec<IoSlice> {
+        let it = self.frozen.iter().map(|v| IoSlice::new(v));
+
+        if !self.active.is_empty() {
+            it.chain([IoSlice::new(&self.active)]).collect()
+        } else {
+            it.collect()
+        }
+    }
+
+    fn bytes(&self, size: usize) -> Bytes {
+        debug_assert!(
+            size <= self.size,
+            "input size {} is larger than bytes size {}",
+            size,
+            self.size
+        );
+
+        if size == 0 {
+            return Bytes::new();
+        }
+
+        if let Some(bs) = self.frozen.front() {
+            if size <= bs.len() {
+                return bs.slice(..size);
+            }
+        }
+
+        let mut remaining = size;
+        let mut result = BytesMut::with_capacity(size);
+
+        // First, go through the frozen buffer.
+        for chunk in &self.frozen {
+            let to_copy = min(remaining, chunk.len());
+            result.extend_from_slice(&chunk[0..to_copy]);
+            remaining -= to_copy;
+
+            if remaining == 0 {
+                break;
+            }
+        }
+
+        // Then, get from the active buffer if necessary.
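+        // The debug_assert above guarantees `size <= self.size`, so any
+        // bytes still missing after draining `frozen` must be in `active`.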
+        if remaining > 0 {
+            result.extend_from_slice(&self.active[0..remaining]);
+        }
+
+        result.freeze()
+    }
+
+    fn is_bytes_optimized(&self, size: usize) -> bool {
+        if let Some(bs) = self.frozen.front() {
+            return size <= bs.len();
+        }
+
+        false
+    }
+
+    fn vectored_bytes(&self, size: usize) -> Vec<Bytes> {
+        debug_assert!(
+            size <= self.size,
+            "input size {} is larger than bytes size {}",
+            size,
+            self.size
+        );
+
+        let mut remaining = size;
+        let mut buf = vec![];
+        for bs in self.frozen.iter() {
+            if remaining == 0 {
+                break;
+            }
+
+            let to_take = min(remaining, bs.len());
+
+            if to_take == bs.len() {
+                buf.push(bs.clone()); // Clone is shallow; no data copy occurs.
+            } else {
+                buf.push(bs.slice(0..to_take));
+            }
+
+            remaining -= to_take;
+        }
+
+        if remaining > 0 {
+            buf.push(Bytes::copy_from_slice(&self.active[0..remaining]));
+        }
+
+        buf
+    }
+}
+
+impl oio::Stream for ChunkedBytes {
+    fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
+        match self.frozen.pop_front() {
+            Some(bs) => {
+                self.size -= bs.len();
+                Poll::Ready(Some(Ok(bs)))
+            }
+            None if !self.active.is_empty() => {
+                self.size -= self.active.len();
+                Poll::Ready(Some(Ok(self.active.split().freeze())))
+            }
+            None => Poll::Ready(None),
+        }
+    }
+
+    fn poll_reset(&mut self, _: &mut Context<'_>) -> Poll<Result<()>> {
+        Poll::Ready(Err(Error::new(
+            ErrorKind::Unsupported,
+            "ChunkedBytes does not support reset",
+        )))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use log::debug;
+    use pretty_assertions::assert_eq;
+    use rand::{thread_rng, Rng, RngCore};
+    use sha2::{Digest, Sha256};
+
+    use super::*;
+    use crate::raw::oio::WriteBuf;
+
+    #[test]
+    fn test_chunked_bytes_write_buf() -> Result<()> {
+        let mut c = ChunkedBytes::with_chunk_size(5);
+
+        c.push(Bytes::from("hello"));
+        assert_eq!(c.len(), 5);
+        assert!(!c.is_empty());
+
+        c.push(Bytes::from("world"));
+        assert_eq!(c.len(), 10);
+        assert!(!c.is_empty());
+
+        // Test chunk
+        let bs = c.chunk();
+        assert_eq!(bs, "hello".as_bytes());
+        assert_eq!(c.len(), 10);
+        assert!(!c.is_empty());
+
+        // The second chunk should return the same content.
+        let bs = c.chunk();
+        assert_eq!(bs, "hello".as_bytes());
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test vectored chunk
+        let bs = c.vectored_chunk();
+        assert_eq!(
+            bs.iter().map(|v| v.as_ref()).collect::<Vec<_>>(),
+            vec!["hello".as_bytes(), "world".as_bytes()]
+        );
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test bytes
+        let bs = c.bytes(4);
+        assert_eq!(bs, Bytes::from("hell"));
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test bytes again
+        let bs = c.bytes(6);
+        assert_eq!(bs, Bytes::from("hellow"));
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test vectored bytes
+        let bs = c.vectored_bytes(4);
+        assert_eq!(bs, vec![Bytes::from("hell")]);
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test vectored bytes again
+        let bs = c.vectored_bytes(6);
+        assert_eq!(bs, vec![Bytes::from("hello"), Bytes::from("w")]);
+        assert_eq!(c.remaining(), 10);
+        assert!(!c.is_empty());
+
+        // Test Advance.
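+        // advance(4) consumes "hell" from the front chunk, leaving "o" and
+        // "world" in frozen: 6 bytes in total.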
+        c.advance(4);
+
+        // Test chunk
+        let bs = c.chunk();
+        assert_eq!(bs, "o".as_bytes());
+        assert_eq!(c.len(), 6);
+        assert!(!c.is_empty());
+
+        c.clear();
+        assert_eq!(c.len(), 0);
+        assert!(c.is_empty());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_fuzz_chunked_bytes_push() -> Result<()> {
+        let _ = tracing_subscriber::fmt()
+            .pretty()
+            .with_test_writer()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let mut rng = thread_rng();
+
+        let chunk_size = rng.gen_range(1..10);
+        let mut cb = ChunkedBytes::with_chunk_size(chunk_size);
+        debug!("test_fuzz_chunked_bytes_push: chunk size: {chunk_size}");
+
+        let mut expected = BytesMut::new();
+        for _ in 0..1000 {
+            let size = rng.gen_range(1..20);
+            debug!("test_fuzz_chunked_bytes_push: write size: {size}");
+
+            let mut content = vec![0; size];
+            rng.fill_bytes(&mut content);
+
+            expected.extend_from_slice(&content);
+            cb.push(Bytes::from(content.clone()));
+
+            let cnt = rng.gen_range(0..expected.len());
+            expected.advance(cnt);
+            cb.advance(cnt);
+
+            assert_eq!(expected.len(), cb.len());
+            assert_eq!(
+                format!("{:x}", Sha256::digest(&expected)),
+                format!("{:x}", Sha256::digest(&cb.bytes(cb.len())))
+            );
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_fuzz_chunked_bytes_extend_from_slice() -> Result<()> {
+        let _ = tracing_subscriber::fmt()
+            .pretty()
+            .with_test_writer()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let mut rng = thread_rng();
+
+        let chunk_size = rng.gen_range(1..10);
+        let mut cb = ChunkedBytes::with_chunk_size(chunk_size);
+        debug!("test_fuzz_chunked_bytes_extend_from_slice: chunk size: {chunk_size}");
+
+        let mut expected = BytesMut::new();
+        for _ in 0..1000 {
+            let size = rng.gen_range(1..20);
+            debug!("test_fuzz_chunked_bytes_extend_from_slice: write size: {size}");
+
+            let mut content = vec![0; size];
+            rng.fill_bytes(&mut content);
+
+            expected.extend_from_slice(&content);
+            cb.extend_from_slice(&content);
+
+            let cnt = rng.gen_range(0..expected.len());
+            expected.advance(cnt);
+            cb.advance(cnt);
+
+            assert_eq!(expected.len(), cb.len());
+            assert_eq!(
+                format!("{:x}", Sha256::digest(&expected)),
+                format!("{:x}", Sha256::digest(&cb.bytes(cb.len())))
+            );
+        }
+
+        Ok(())
+    }
+}
diff --git a/core/src/raw/oio/buf/mod.rs b/core/src/raw/oio/buf/mod.rs
new file mode 100644
index 00000000000..dfd3663e56f
--- /dev/null
+++ b/core/src/raw/oio/buf/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
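+
+//! Buffer utilities for OpenDAL's write path: [`ChunkedBytes`] and the
+//! [`WriteBuf`] trait re-exported below.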
+
+mod chunked_bytes;
+pub use chunked_bytes::ChunkedBytes;
+
+mod write_buf;
+pub use write_buf::WriteBuf;
diff --git a/core/src/raw/oio/buf.rs b/core/src/raw/oio/buf/write_buf.rs
similarity index 90%
rename from core/src/raw/oio/buf.rs
rename to core/src/raw/oio/buf/write_buf.rs
index 50f3c8911bb..d70ce640a0b 100644
--- a/core/src/raw/oio/buf.rs
+++ b/core/src/raw/oio/buf/write_buf.rs
@@ -74,6 +74,22 @@ pub trait WriteBuf: Send + Sync {
     /// This function will panic if size > self.remaining().
     fn bytes(&self, size: usize) -> Bytes;
 
+    /// Returns true if the underlying buffer is optimized for bytes with given size.
+    ///
+    /// # Notes
+    ///
+    /// This function is used to avoid copy when possible. Implementors should return `true` if
+    /// the given `self.bytes(size)` could be done without cost. For example, when the underlying
+    /// buffer is `Bytes`.
+    ///
+    /// # Panics
+    ///
+    /// This function will panic if size > self.remaining().
+    fn is_bytes_optimized(&self, size: usize) -> bool {
+        let _ = size;
+        false
+    }
+
     /// Returns a vectored bytes of the underlying buffer at the current position and of
     /// length between 0 and Buf::remaining().
     ///
@@ -114,6 +130,10 @@ macro_rules! deref_forward_buf {
             (**self).bytes(size)
         }
 
+        fn is_bytes_optimized(&self, size: usize) -> bool {
+            (**self).is_bytes_optimized(size)
+        }
+
         fn vectored_bytes(&self, size: usize) -> Vec<Bytes> {
             (**self).vectored_bytes(size)
         }
@@ -231,6 +251,11 @@ impl WriteBuf for Bytes {
         self.slice(..size)
     }
 
+    #[inline]
+    fn is_bytes_optimized(&self, _: usize) -> bool {
+        true
+    }
+
     #[inline]
     fn vectored_bytes(&self, size: usize) -> Vec<Bytes> {
         vec![self.slice(..size)]
diff --git a/core/src/raw/oio/cursor.rs b/core/src/raw/oio/cursor.rs
index 2c79451c13e..ddead52c9d5 100644
--- a/core/src/raw/oio/cursor.rs
+++ b/core/src/raw/oio/cursor.rs
@@ -22,9 +22,7 @@ use std::io::SeekFrom;
 use std::task::Context;
 use std::task::Poll;
 
-use bytes::Buf;
 use bytes::Bytes;
-use bytes::BytesMut;
 
 use crate::raw::*;
 use crate::*;
@@ -318,7 +316,7 @@ impl ChunkedCursor {
 
     #[cfg(test)]
     fn concat(&self) -> Bytes {
-        let mut bs = BytesMut::new();
+        let mut bs = bytes::BytesMut::new();
         for v in self.inner.iter().skip(self.idx) {
             bs.extend_from_slice(v);
         }
@@ -343,166 +341,6 @@ impl oio::Stream for ChunkedCursor {
     }
 }
 
-/// VectorCursor is the cursor for [`Vec<Bytes>`] that implements [`oio::Stream`]
-pub struct VectorCursor {
-    inner: VecDeque<Bytes>,
-    size: usize,
-}
-
-impl Default for VectorCursor {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl VectorCursor {
-    /// Create a new vector cursor.
-    pub fn new() -> Self {
-        Self {
-            inner: VecDeque::new(),
-            size: 0,
-        }
-    }
-
-    /// Returns `true` if current vector is empty.
-    pub fn is_empty(&self) -> bool {
-        self.size == 0
-    }
-
-    /// Return current bytes size of current vector.
-    pub fn len(&self) -> usize {
-        self.size
-    }
-
-    /// Push a new bytes into vector cursor.
-    pub fn push(&mut self, bs: Bytes) {
-        self.size += bs.len();
-        self.inner.push_back(bs);
-    }
-
-    /// Pop a bytes from vector cursor.
-    pub fn pop(&mut self) {
-        let bs = self.inner.pop_back();
-        self.size -= bs.expect("pop bytes must exist").len()
-    }
-
-    /// Clear the entire vector.
-    pub fn clear(&mut self) {
-        self.inner.clear();
-        self.size = 0;
-    }
-
-    /// Peak will read and copy exactly n bytes from current cursor
-    /// without change it's content.
-    ///
-    /// This function is useful if you want to read a fixed size
-    /// content to make sure it aligned.
-    ///
-    /// # Panics
-    ///
-    /// Panics if n is larger than current size.
-    ///
-    /// # TODO
-    ///
-    /// Optimize to avoid data copy.
-    pub fn peak_exact(&self, n: usize) -> Bytes {
-        assert!(n <= self.size, "peak size must smaller than current size");
-
-        // Avoid data copy if n is smaller than first chunk.
-        if self.inner[0].len() >= n {
-            return self.inner[0].slice(..n);
-        }
-
-        let mut bs = BytesMut::with_capacity(n);
-        let mut n = n;
-        for b in &self.inner {
-            if n == 0 {
-                break;
-            }
-            let len = b.len().min(n);
-            bs.extend_from_slice(&b[..len]);
-            n -= len;
-        }
-        bs.freeze()
-    }
-
-    /// peak_at_least will read and copy at least n bytes from current
-    /// cursor without change it's content.
-    ///
-    /// This function is useful if you only want to make sure the
-    /// returning bytes is larger.
-    ///
-    /// # Panics
-    ///
-    /// Panics if n is larger than current size.
-    ///
-    /// # TODO
-    ///
-    /// Optimize to avoid data copy.
-    pub fn peak_at_least(&self, n: usize) -> Bytes {
-        assert!(n <= self.size, "peak size must smaller than current size");
-
-        // Avoid data copy if n is smaller than first chunk.
-        if self.inner[0].len() >= n {
-            return self.inner[0].clone();
-        }
-
-        let mut bs = BytesMut::with_capacity(n);
-        let mut n = n;
-        for b in &self.inner {
-            if n == 0 {
-                break;
-            }
-            let len = b.len().min(n);
-            bs.extend_from_slice(&b[..len]);
-            n -= len;
-        }
-        bs.freeze()
-    }
-
-    /// peak all will read and copy all bytes from current cursor
-    /// without change it's content.
-    ///
-    /// TODO: we need to find a way to avoid copy all content here.
-    pub fn peak_all(&self) -> Bytes {
-        // Avoid data copy if we only have one bytes.
-        if self.inner.len() == 1 {
-            return self.inner[0].clone();
-        }
-
-        let mut bs = BytesMut::with_capacity(self.len());
-        for b in &self.inner {
-            bs.extend_from_slice(b);
-        }
-        bs.freeze()
-    }
-
-    /// Take will consume n bytes from current cursor.
-    ///
-    /// # Panics
-    ///
-    /// Panics if n is larger than current size.
-    pub fn take(&mut self, n: usize) {
-        assert!(n <= self.size, "take size must smamller than current size");
-
-        // Update current size
-        self.size -= n;
-
-        let mut n = n;
-        while n > 0 {
-            assert!(!self.inner.is_empty(), "inner must not be empty");
-
-            if self.inner[0].len() <= n {
-                n -= self.inner[0].len();
-                self.inner.pop_front();
-            } else {
-                self.inner[0].advance(n);
-                n = 0;
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use pretty_assertions::assert_eq;
@@ -515,26 +353,6 @@ mod tests {
     use super::*;
     use crate::raw::oio::StreamExt;
 
-    #[test]
-    fn test_vector_cursor() {
-        let mut vc = VectorCursor::new();
-
-        vc.push(Bytes::from("hello"));
-        vc.push(Bytes::from("world"));
-
-        assert_eq!(vc.peak_exact(1), Bytes::from("h"));
-        assert_eq!(vc.peak_exact(1), Bytes::from("h"));
-        assert_eq!(vc.peak_exact(4), Bytes::from("hell"));
-        assert_eq!(vc.peak_exact(10), Bytes::from("helloworld"));
-
-        vc.take(1);
-        assert_eq!(vc.peak_exact(1), Bytes::from("e"));
-        vc.take(1);
-        assert_eq!(vc.peak_exact(1), Bytes::from("l"));
-        vc.take(5);
-        assert_eq!(vc.peak_exact(1), Bytes::from("r"));
-    }
-
     #[tokio::test]
     async fn test_chunked_cursor() -> Result<()> {
         let mut c = ChunkedCursor::new();
diff --git a/core/src/raw/oio/mod.rs b/core/src/raw/oio/mod.rs
index 8c353ca50d9..29cf8e474b6 100644
--- a/core/src/raw/oio/mod.rs
+++ b/core/src/raw/oio/mod.rs
@@ -37,10 +37,9 @@ pub use page::*;
 mod cursor;
 pub use cursor::ChunkedCursor;
 pub use cursor::Cursor;
-pub use cursor::VectorCursor;
 
 mod entry;
 pub use entry::Entry;
 
 mod buf;
-pub use buf::WriteBuf;
+pub use buf::*;
diff --git a/core/src/raw/oio/write/exact_buf_write.rs b/core/src/raw/oio/write/exact_buf_write.rs
index 144d50ff0e5..02ab9d92657 100644
--- a/core/src/raw/oio/write/exact_buf_write.rs
+++ b/core/src/raw/oio/write/exact_buf_write.rs
@@ -15,14 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::cmp::min;
 use std::task::ready;
 use std::task::Context;
 use std::task::Poll;
 
 use async_trait::async_trait;
-use bytes::Bytes;
-use bytes::BytesMut;
 
 use crate::raw::oio::WriteBuf;
 use crate::raw::*;
@@ -44,7 +41,7 @@ pub struct ExactBufWriter<W: oio::Write> {
     /// The size for buffer, we will flush the underlying storage at the size of this buffer.
     buffer_size: usize,
 
-    buffer: Buffer,
+    buffer: oio::ChunkedBytes,
 }
 
 impl<W: oio::Write> ExactBufWriter<W> {
@@ -53,83 +50,33 @@ impl<W: oio::Write> ExactBufWriter<W> {
     /// Create a new exact buf writer.
     pub fn new(inner: W, buffer_size: usize) -> Self {
         Self {
             inner,
             buffer_size,
-            buffer: Buffer::Empty,
+            buffer: oio::ChunkedBytes::default(),
         }
     }
 }
 
-enum Buffer {
-    Empty,
-    Filling(BytesMut),
-    Consuming(Bytes),
-}
-
 #[async_trait]
 impl<W: oio::Write> oio::Write for ExactBufWriter<W> {
     fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn WriteBuf) -> Poll<Result<usize>> {
-        loop {
-            match &mut self.buffer {
-                Buffer::Empty => {
-                    if bs.remaining() >= self.buffer_size {
-                        self.buffer = Buffer::Consuming(bs.bytes(self.buffer_size));
-                        return Poll::Ready(Ok(self.buffer_size));
-                    }
-
-                    let chunk = bs.chunk();
-                    let mut fill = BytesMut::with_capacity(chunk.len());
-                    fill.extend_from_slice(chunk);
-                    self.buffer = Buffer::Filling(fill);
-                    return Poll::Ready(Ok(chunk.len()));
-                }
-                Buffer::Filling(fill) => {
-                    if fill.len() >= self.buffer_size {
-                        self.buffer = Buffer::Consuming(fill.split().freeze());
-                        continue;
-                    }
-
-                    let size = min(self.buffer_size - fill.len(), bs.chunk().len());
-                    fill.extend_from_slice(&bs.chunk()[..size]);
-                    return Poll::Ready(Ok(size));
-                }
-                Buffer::Consuming(consume) => {
-                    // Make sure filled buffer has been flushed.
-                    //
-                    // TODO: maybe we can re-fill it after a successful write.
-                    while !consume.is_empty() {
-                        let n = ready!(self.inner.poll_write(cx, consume)?);
-                        consume.advance(n);
-                    }
-                    self.buffer = Buffer::Empty;
-                }
-            }
+        if self.buffer.len() >= self.buffer_size {
+            let written = ready!(self.inner.poll_write(cx, &self.buffer)?);
+            self.buffer.advance(written);
         }
+
+        let remaining = self.buffer_size - self.buffer.len();
+        let written = self.buffer.extend_from_write_buf(remaining, bs);
+        Poll::Ready(Ok(written))
     }
 
     fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        self.buffer = Buffer::Empty;
+        self.buffer.clear();
         self.inner.poll_abort(cx)
     }
 
     fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
-        loop {
-            match &mut self.buffer {
-                Buffer::Empty => break,
-                Buffer::Filling(fill) => {
-                    self.buffer = Buffer::Consuming(fill.split().freeze());
-                    continue;
-                }
-                Buffer::Consuming(consume) => {
-                    // Make sure filled buffer has been flushed.
-                    //
-                    // TODO: maybe we can re-fill it after a successful write.
-                    while !consume.is_empty() {
-                        let n = ready!(self.inner.poll_write(cx, &consume))?;
-                        consume.advance(n);
-                    }
-                    self.buffer = Buffer::Empty;
-                    break;
-                }
-            }
+        while !self.buffer.is_empty() {
+            let n = ready!(self.inner.poll_write(cx, &self.buffer))?;
+            self.buffer.advance(n);
         }
 
         self.inner.poll_close(cx)
@@ -138,6 +85,7 @@
 
 #[cfg(test)]
 mod tests {
+    use bytes::Bytes;
     use log::debug;
     use pretty_assertions::assert_eq;
     use rand::thread_rng;
@@ -231,7 +179,7 @@
             let mut bs = Bytes::from(content.clone());
             while !bs.is_empty() {
                 let n = writer.write(&bs).await?;
-                bs.advance(n as usize);
+                bs.advance(n);
             }
         }
         writer.close().await?;