diff --git a/core/src/raw/oio/write/multipart_upload_write.rs b/core/src/raw/oio/write/multipart_upload_write.rs index 7124554ff81..67f18ecf3f2 100644 --- a/core/src/raw/oio/write/multipart_upload_write.rs +++ b/core/src/raw/oio/write/multipart_upload_write.rs @@ -21,6 +21,7 @@ use std::task::Context; use std::task::Poll; use async_trait::async_trait; +use bytes::Bytes; use futures::future::BoxFuture; use crate::raw::*; @@ -37,8 +38,26 @@ use crate::*; /// - Services impl `MultipartUploadWrite` /// - `MultipartUploadWriter` impl `Write` /// - Expose `MultipartUploadWriter` as `Accessor::Writer` +/// +/// # Notes +/// +/// `MultipartUploadWrite` has a oneshot optimization when `write` has been called only once: +/// +/// ```no_build +/// w.write(bs).await?; +/// w.close().await?; +/// ``` +/// +/// We will use `write_once` instead of starting a new multipart upload. #[async_trait] pub trait MultipartUploadWrite: Send + Sync + Unpin + 'static { + /// write_once is used to write the data to the underlying storage at once. + /// + /// MultipartUploadWriter will call this API when: + /// + /// - All the data has been written to the buffer and we can perform the upload at once. + async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>; + /// initiate_part will call start a multipart upload and return the upload id. 
/// /// MultipartUploadWriter will call this when: @@ -90,6 +109,7 @@ pub struct MultipartUploadPart { pub struct MultipartUploadWriter { state: State, + cache: Option, upload_id: Option>, parts: Vec, } @@ -97,7 +117,7 @@ pub struct MultipartUploadWriter { enum State { Idle(Option), Init(BoxFuture<'static, (W, Result)>), - Write(BoxFuture<'static, (W, usize, Result)>), + Write(BoxFuture<'static, (W, Result)>), Close(BoxFuture<'static, (W, Result<()>)>), Abort(BoxFuture<'static, (W, Result<()>)>), } @@ -113,6 +133,7 @@ impl MultipartUploadWriter { Self { state: State::Idle(Some(inner)), + cache: None, upload_id: None, parts: Vec::new(), } @@ -128,15 +149,15 @@ where loop { match &mut self.state { State::Idle(w) => { - let w = w.take().expect("writer must be valid"); match self.upload_id.as_ref() { Some(upload_id) => { - let size = bs.remaining(); - let bs = bs.copy_to_bytes(size); let upload_id = upload_id.clone(); let part_number = self.parts.len(); + let bs = self.cache.clone().expect("cache must be valid").clone(); + let w = w.take().expect("writer must be valid"); self.state = State::Write(Box::pin(async move { + let size = bs.len(); let part = w .write_part( &upload_id, @@ -146,10 +167,18 @@ where ) .await; - (w, size, part) + (w, part) })); } None => { + // Fill cache with the first write. 
+ if self.cache.is_none() { + let size = bs.remaining(); + self.cache = Some(bs.copy_to_bytes(size)); + return Poll::Ready(Ok(size)); + } + + let w = w.take().expect("writer must be valid"); self.state = State::Init(Box::pin(async move { let upload_id = w.initiate_part().await; (w, upload_id) @@ -163,10 +192,12 @@ where self.upload_id = Some(Arc::new(upload_id?)); } State::Write(fut) => { - let (w, size, part) = ready!(fut.as_mut().poll(cx)); + let (w, part) = ready!(fut.as_mut().poll(cx)); self.state = State::Idle(Some(w)); - self.parts.push(part?); + // Replace the cache when last write succeeded + let size = bs.remaining(); + self.cache = Some(bs.copy_to_bytes(size)); return Poll::Ready(Ok(size)); } State::Close(_) => { @@ -191,25 +222,57 @@ where match self.upload_id.clone() { Some(upload_id) => { let parts = self.parts.clone(); - self.state = State::Close(Box::pin(async move { - let res = w.complete_part(&upload_id, &parts).await; - (w, res) - })); + match self.cache.clone() { + Some(bs) => { + let upload_id = upload_id.clone(); + self.state = State::Write(Box::pin(async move { + let size = bs.len(); + let part = w + .write_part( + &upload_id, + parts.len(), + size as u64, + AsyncBody::Bytes(bs), + ) + .await; + (w, part) + })); + } + None => { + self.state = State::Close(Box::pin(async move { + let res = w.complete_part(&upload_id, &parts).await; + (w, res) + })); + } + } } - None => return Poll::Ready(Ok(())), + None => match self.cache.clone() { + Some(bs) => { + self.state = State::Close(Box::pin(async move { + let size = bs.len(); + let res = w.write_once(size as u64, AsyncBody::Bytes(bs)).await; + (w, res) + })); + } + None => return Poll::Ready(Ok(())), + }, } } State::Close(fut) => { let (w, res) = futures::ready!(fut.as_mut().poll(cx)); self.state = State::Idle(Some(w)); + self.cache = None; return Poll::Ready(res); } State::Init(_) => unreachable!( "MultipartUploadWriter must not go into State::Init during poll_close" ), - State::Write(_) => 
unreachable!( - "MultipartUploadWriter must not go into State::Write during poll_close" - ), + State::Write(fut) => { + let (w, part) = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle(Some(w)); + self.parts.push(part?); + self.cache = None; + } State::Abort(_) => unreachable!( "MultipartUploadWriter must not go into State::Abort during poll_close" ), @@ -229,7 +292,10 @@ where (w, res) })); } - None => return Poll::Ready(Ok(())), + None => { + self.cache = None; + return Poll::Ready(Ok(())); + } } } State::Abort(fut) => { diff --git a/core/src/services/cos/backend.rs b/core/src/services/cos/backend.rs index 46e3b3064e5..eced1ccc95e 100644 --- a/core/src/services/cos/backend.rs +++ b/core/src/services/cos/backend.rs @@ -337,11 +337,9 @@ impl Accessor for CosBackend { let writer = CosWriter::new(self.core.clone(), path, args.clone()); let w = if args.append() { - CosWriters::Three(oio::AppendObjectWriter::new(writer)) - } else if args.content_length().is_some() { - CosWriters::One(oio::OneShotWriter::new(writer)) + CosWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - CosWriters::Two(oio::MultipartUploadWriter::new(writer)) + CosWriters::One(oio::MultipartUploadWriter::new(writer)) }; let w = if let Some(buffer_size) = args.buffer_size() { diff --git a/core/src/services/cos/writer.rs b/core/src/services/cos/writer.rs index fba7cea09bd..11681b99ad6 100644 --- a/core/src/services/cos/writer.rs +++ b/core/src/services/cos/writer.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use async_trait::async_trait; -use bytes::Bytes; use http::StatusCode; use super::core::*; @@ -26,11 +25,8 @@ use super::error::parse_error; use crate::raw::*; use crate::*; -pub type CosWriters = oio::ThreeWaysWriter< - oio::OneShotWriter, - oio::MultipartUploadWriter, - oio::AppendObjectWriter, ->; +pub type CosWriters = + oio::TwoWaysWriter, oio::AppendObjectWriter>; pub struct CosWriter { core: Arc, @@ -50,16 +46,15 @@ impl CosWriter { } #[async_trait] -impl oio::OneShotWrite for 
CosWriter { - async fn write_once(&self, buf: Bytes) -> Result<()> { - let size = buf.len(); +impl oio::MultipartUploadWrite for CosWriter { + async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self.core.cos_put_object_request( &self.path, - Some(size as u64), + Some(size), self.op.content_type(), self.op.content_disposition(), self.op.cache_control(), - AsyncBody::Bytes(buf), + body, )?; self.core.sign(&mut req).await?; @@ -76,10 +71,7 @@ impl oio::OneShotWrite for CosWriter { _ => Err(parse_error(resp).await?), } } -} -#[async_trait] -impl oio::MultipartUploadWrite for CosWriter { async fn initiate_part(&self) -> Result { let resp = self .core diff --git a/core/src/services/obs/backend.rs b/core/src/services/obs/backend.rs index cc14eedd02f..9cca757a856 100644 --- a/core/src/services/obs/backend.rs +++ b/core/src/services/obs/backend.rs @@ -375,11 +375,9 @@ impl Accessor for ObsBackend { let writer = ObsWriter::new(self.core.clone(), path, args.clone()); let w = if args.append() { - ObsWriters::Three(oio::AppendObjectWriter::new(writer)) - } else if args.content_length().is_some() { - ObsWriters::One(oio::OneShotWriter::new(writer)) + ObsWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - ObsWriters::Two(oio::MultipartUploadWriter::new(writer)) + ObsWriters::One(oio::MultipartUploadWriter::new(writer)) }; let w = if let Some(buffer_size) = args.buffer_size() { diff --git a/core/src/services/obs/writer.rs b/core/src/services/obs/writer.rs index 62882b1ca5d..94b8380a7f8 100644 --- a/core/src/services/obs/writer.rs +++ b/core/src/services/obs/writer.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use async_trait::async_trait; -use bytes::Bytes; use http::StatusCode; use super::core::*; @@ -27,11 +26,8 @@ use crate::raw::oio::MultipartUploadPart; use crate::raw::*; use crate::*; -pub type ObsWriters = oio::ThreeWaysWriter< - oio::OneShotWriter, - oio::MultipartUploadWriter, - oio::AppendObjectWriter, ->; +pub type ObsWriters = + 
oio::TwoWaysWriter, oio::AppendObjectWriter>; pub struct ObsWriter { core: Arc, @@ -51,15 +47,14 @@ impl ObsWriter { } #[async_trait] -impl oio::OneShotWrite for ObsWriter { - async fn write_once(&self, bs: Bytes) -> Result<()> { - let size = bs.len(); +impl oio::MultipartUploadWrite for ObsWriter { + async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self.core.obs_put_object_request( &self.path, - Some(size as u64), + Some(size), self.op.content_type(), self.op.cache_control(), - AsyncBody::Bytes(bs), + body, )?; self.core.sign(&mut req).await?; @@ -76,10 +71,7 @@ impl oio::OneShotWrite for ObsWriter { _ => Err(parse_error(resp).await?), } } -} -#[async_trait] -impl oio::MultipartUploadWrite for ObsWriter { async fn initiate_part(&self) -> Result { let resp = self .core diff --git a/core/src/services/oss/backend.rs b/core/src/services/oss/backend.rs index 212c42ce064..6d027ba5bb2 100644 --- a/core/src/services/oss/backend.rs +++ b/core/src/services/oss/backend.rs @@ -473,11 +473,9 @@ impl Accessor for OssBackend { let writer = OssWriter::new(self.core.clone(), path, args.clone()); let w = if args.append() { - OssWriters::Three(oio::AppendObjectWriter::new(writer)) - } else if args.content_length().is_some() { - OssWriters::One(oio::OneShotWriter::new(writer)) + OssWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - OssWriters::Two(oio::MultipartUploadWriter::new(writer)) + OssWriters::One(oio::MultipartUploadWriter::new(writer)) }; let w = if let Some(buffer_size) = args.buffer_size() { diff --git a/core/src/services/oss/writer.rs b/core/src/services/oss/writer.rs index 56d262f17ed..296055e29dc 100644 --- a/core/src/services/oss/writer.rs +++ b/core/src/services/oss/writer.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use async_trait::async_trait; -use bytes::Bytes; use http::StatusCode; use super::core::*; @@ -26,11 +25,8 @@ use super::error::parse_error; use crate::raw::*; use crate::*; -pub type OssWriters = 
oio::ThreeWaysWriter< - oio::OneShotWriter, - oio::MultipartUploadWriter, - oio::AppendObjectWriter, ->; +pub type OssWriters = + oio::TwoWaysWriter, oio::AppendObjectWriter>; pub struct OssWriter { core: Arc, @@ -50,16 +46,15 @@ impl OssWriter { } #[async_trait] -impl oio::OneShotWrite for OssWriter { - async fn write_once(&self, bs: Bytes) -> Result<()> { - let size = bs.len(); +impl oio::MultipartUploadWrite for OssWriter { + async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self.core.oss_put_object_request( &self.path, - Some(size as u64), + Some(size), self.op.content_type(), self.op.content_disposition(), self.op.cache_control(), - AsyncBody::Bytes(bs), + body, false, )?; @@ -77,10 +72,7 @@ impl oio::OneShotWrite for OssWriter { _ => Err(parse_error(resp).await?), } } -} -#[async_trait] -impl oio::MultipartUploadWrite for OssWriter { async fn initiate_part(&self) -> Result { let resp = self .core diff --git a/core/src/services/s3/backend.rs b/core/src/services/s3/backend.rs index d2d8bb039c3..aa2b95eb841 100644 --- a/core/src/services/s3/backend.rs +++ b/core/src/services/s3/backend.rs @@ -977,11 +977,7 @@ impl Accessor for S3Backend { async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { let writer = S3Writer::new(self.core.clone(), path, args.clone()); - let w = if args.content_length().is_some() { - S3Writers::One(oio::OneShotWriter::new(writer)) - } else { - S3Writers::Two(oio::MultipartUploadWriter::new(writer)) - }; + let w = oio::MultipartUploadWriter::new(writer); let w = if let Some(buffer_size) = args.buffer_size() { let buffer_size = max(MINIMUM_MULTIPART_SIZE, buffer_size); diff --git a/core/src/services/s3/writer.rs b/core/src/services/s3/writer.rs index a3a1bd5bdd4..76c874ed0c1 100644 --- a/core/src/services/s3/writer.rs +++ b/core/src/services/s3/writer.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use async_trait::async_trait; -use bytes::Bytes; use http::StatusCode; use 
super::core::*; @@ -26,8 +25,7 @@ use super::error::parse_error; use crate::raw::*; use crate::*; -pub type S3Writers = - oio::TwoWaysWriter, oio::MultipartUploadWriter>; +pub type S3Writers = oio::MultipartUploadWriter; pub struct S3Writer { core: Arc, @@ -47,17 +45,15 @@ impl S3Writer { } #[async_trait] -impl oio::OneShotWrite for S3Writer { - async fn write_once(&self, bs: Bytes) -> Result<()> { - let size = bs.len(); - +impl oio::MultipartUploadWrite for S3Writer { + async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self.core.s3_put_object_request( &self.path, - Some(size as u64), + Some(size), self.op.content_type(), self.op.content_disposition(), self.op.cache_control(), - AsyncBody::Bytes(bs), + body, )?; self.core.sign(&mut req).await?; @@ -74,10 +70,7 @@ impl oio::OneShotWrite for S3Writer { _ => Err(parse_error(resp).await?), } } -} -#[async_trait] -impl oio::MultipartUploadWrite for S3Writer { async fn initiate_part(&self) -> Result { let resp = self .core