Skip to content

Commit

Permalink
refactor: Polish multipart writer to allow oneshot optimization (#3031)
Browse files Browse the repository at this point in the history
* polish multipart writer

Signed-off-by: Xuanwo <[email protected]>

* Fix doctest

Signed-off-by: Xuanwo <[email protected]>

---------

Signed-off-by: Xuanwo <[email protected]>
  • Loading branch information
Xuanwo authored Sep 11, 2023
1 parent 4a52ba6 commit bf85c1f
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 87 deletions.
98 changes: 82 additions & 16 deletions core/src/raw/oio/write/multipart_upload_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::task::Context;
use std::task::Poll;

use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;

use crate::raw::*;
Expand All @@ -37,8 +38,26 @@ use crate::*;
/// - Services impl `MultipartUploadWrite`
/// - `MultipartUploadWriter` impl `Write`
/// - Expose `MultipartUploadWriter` as `Accessor::Writer`
///
/// # Notes
///
/// `MultipartUploadWrite` has an oneshot optimization when `write` has been called only once:
///
/// ```no_build
/// w.write(bs).await?;
/// w.close().await?;
/// ```
///
/// We will use `write_once` instead of starting a new multipart upload.
#[async_trait]
pub trait MultipartUploadWrite: Send + Sync + Unpin + 'static {
/// write_once is used to write the data to underlying storage at once.
///
/// MultipartUploadWriter will call this API when:
///
/// - All the data has been written to the buffer and we can perform the upload at once.
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>;

/// initiate_part will call start a multipart upload and return the upload id.
///
/// MultipartUploadWriter will call this when:
Expand Down Expand Up @@ -90,14 +109,15 @@ pub struct MultipartUploadPart {
pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
state: State<W>,

cache: Option<Bytes>,
upload_id: Option<Arc<String>>,
parts: Vec<MultipartUploadPart>,
}

enum State<W> {
Idle(Option<W>),
Init(BoxFuture<'static, (W, Result<String>)>),
Write(BoxFuture<'static, (W, usize, Result<MultipartUploadPart>)>),
Write(BoxFuture<'static, (W, Result<MultipartUploadPart>)>),
Close(BoxFuture<'static, (W, Result<()>)>),
Abort(BoxFuture<'static, (W, Result<()>)>),
}
Expand All @@ -113,6 +133,7 @@ impl<W: MultipartUploadWrite> MultipartUploadWriter<W> {
Self {
state: State::Idle(Some(inner)),

cache: None,
upload_id: None,
parts: Vec::new(),
}
Expand All @@ -128,15 +149,15 @@ where
loop {
match &mut self.state {
State::Idle(w) => {
let w = w.take().expect("writer must be valid");
match self.upload_id.as_ref() {
Some(upload_id) => {
let size = bs.remaining();
let bs = bs.copy_to_bytes(size);
let upload_id = upload_id.clone();
let part_number = self.parts.len();

let bs = self.cache.clone().expect("cache must be valid").clone();
let w = w.take().expect("writer must be valid");
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
&upload_id,
Expand All @@ -146,10 +167,18 @@ where
)
.await;

(w, size, part)
(w, part)
}));
}
None => {
// Fill cache with the first write.
if self.cache.is_none() {
let size = bs.remaining();
self.cache = Some(bs.copy_to_bytes(size));
return Poll::Ready(Ok(size));
}

let w = w.take().expect("writer must be valid");
self.state = State::Init(Box::pin(async move {
let upload_id = w.initiate_part().await;
(w, upload_id)
Expand All @@ -163,10 +192,12 @@ where
self.upload_id = Some(Arc::new(upload_id?));
}
State::Write(fut) => {
let (w, size, part) = ready!(fut.as_mut().poll(cx));
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));

self.parts.push(part?);
// Replace the cache when last write succeeded
let size = bs.remaining();
self.cache = Some(bs.copy_to_bytes(size));
return Poll::Ready(Ok(size));
}
State::Close(_) => {
Expand All @@ -191,25 +222,57 @@ where
match self.upload_id.clone() {
Some(upload_id) => {
let parts = self.parts.clone();
self.state = State::Close(Box::pin(async move {
let res = w.complete_part(&upload_id, &parts).await;
(w, res)
}));
match self.cache.clone() {
Some(bs) => {
let upload_id = upload_id.clone();
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
&upload_id,
parts.len(),
size as u64,
AsyncBody::Bytes(bs),
)
.await;
(w, part)
}));
}
None => {
self.state = State::Close(Box::pin(async move {
let res = w.complete_part(&upload_id, &parts).await;
(w, res)
}));
}
}
}
None => return Poll::Ready(Ok(())),
None => match self.cache.clone() {
Some(bs) => {
self.state = State::Close(Box::pin(async move {
let size = bs.len();
let res = w.write_once(size as u64, AsyncBody::Bytes(bs)).await;
(w, res)
}));
}
None => return Poll::Ready(Ok(())),
},
}
}
State::Close(fut) => {
let (w, res) = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.cache = None;
return Poll::Ready(res);
}
State::Init(_) => unreachable!(
"MultipartUploadWriter must not go into State::Init during poll_close"
),
State::Write(_) => unreachable!(
"MultipartUploadWriter must not go into State::Write during poll_close"
),
State::Write(fut) => {
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.parts.push(part?);
self.cache = None;
}
State::Abort(_) => unreachable!(
"MultipartUploadWriter must not go into State::Abort during poll_close"
),
Expand All @@ -229,7 +292,10 @@ where
(w, res)
}));
}
None => return Poll::Ready(Ok(())),
None => {
self.cache = None;
return Poll::Ready(Ok(()));
}
}
}
State::Abort(fut) => {
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/cos/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,9 @@ impl Accessor for CosBackend {
let writer = CosWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
CosWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
CosWriters::One(oio::OneShotWriter::new(writer))
CosWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
CosWriters::Two(oio::MultipartUploadWriter::new(writer))
CosWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/cos/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
use super::error::parse_error;
use crate::raw::*;
use crate::*;

pub type CosWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<CosWriter>,
oio::MultipartUploadWriter<CosWriter>,
oio::AppendObjectWriter<CosWriter>,
>;
pub type CosWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<CosWriter>, oio::AppendObjectWriter<CosWriter>>;

pub struct CosWriter {
core: Arc<CosCore>,
Expand All @@ -50,16 +46,15 @@ impl CosWriter {
}

#[async_trait]
impl oio::OneShotWrite for CosWriter {
async fn write_once(&self, buf: Bytes) -> Result<()> {
let size = buf.len();
impl oio::MultipartUploadWrite for CosWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.cos_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.content_disposition(),
self.op.cache_control(),
AsyncBody::Bytes(buf),
body,
)?;

self.core.sign(&mut req).await?;
Expand All @@ -76,10 +71,7 @@ impl oio::OneShotWrite for CosWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for CosWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/obs/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,9 @@ impl Accessor for ObsBackend {
let writer = ObsWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
ObsWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
ObsWriters::One(oio::OneShotWriter::new(writer))
ObsWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
ObsWriters::Two(oio::MultipartUploadWriter::new(writer))
ObsWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/obs/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
Expand All @@ -27,11 +26,8 @@ use crate::raw::oio::MultipartUploadPart;
use crate::raw::*;
use crate::*;

pub type ObsWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<ObsWriter>,
oio::MultipartUploadWriter<ObsWriter>,
oio::AppendObjectWriter<ObsWriter>,
>;
pub type ObsWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<ObsWriter>, oio::AppendObjectWriter<ObsWriter>>;

pub struct ObsWriter {
core: Arc<ObsCore>,
Expand All @@ -51,15 +47,14 @@ impl ObsWriter {
}

#[async_trait]
impl oio::OneShotWrite for ObsWriter {
async fn write_once(&self, bs: Bytes) -> Result<()> {
let size = bs.len();
impl oio::MultipartUploadWrite for ObsWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.obs_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.cache_control(),
AsyncBody::Bytes(bs),
body,
)?;

self.core.sign(&mut req).await?;
Expand All @@ -76,10 +71,7 @@ impl oio::OneShotWrite for ObsWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for ObsWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
6 changes: 2 additions & 4 deletions core/src/services/oss/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,9 @@ impl Accessor for OssBackend {
let writer = OssWriter::new(self.core.clone(), path, args.clone());

let w = if args.append() {
OssWriters::Three(oio::AppendObjectWriter::new(writer))
} else if args.content_length().is_some() {
OssWriters::One(oio::OneShotWriter::new(writer))
OssWriters::Two(oio::AppendObjectWriter::new(writer))
} else {
OssWriters::Two(oio::MultipartUploadWriter::new(writer))
OssWriters::One(oio::MultipartUploadWriter::new(writer))
};

let w = if let Some(buffer_size) = args.buffer_size() {
Expand Down
20 changes: 6 additions & 14 deletions core/src/services/oss/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use http::StatusCode;

use super::core::*;
use super::error::parse_error;
use crate::raw::*;
use crate::*;

pub type OssWriters = oio::ThreeWaysWriter<
oio::OneShotWriter<OssWriter>,
oio::MultipartUploadWriter<OssWriter>,
oio::AppendObjectWriter<OssWriter>,
>;
pub type OssWriters =
oio::TwoWaysWriter<oio::MultipartUploadWriter<OssWriter>, oio::AppendObjectWriter<OssWriter>>;

pub struct OssWriter {
core: Arc<OssCore>,
Expand All @@ -50,16 +46,15 @@ impl OssWriter {
}

#[async_trait]
impl oio::OneShotWrite for OssWriter {
async fn write_once(&self, bs: Bytes) -> Result<()> {
let size = bs.len();
impl oio::MultipartUploadWrite for OssWriter {
async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> {
let mut req = self.core.oss_put_object_request(
&self.path,
Some(size as u64),
Some(size),
self.op.content_type(),
self.op.content_disposition(),
self.op.cache_control(),
AsyncBody::Bytes(bs),
body,
false,
)?;

Expand All @@ -77,10 +72,7 @@ impl oio::OneShotWrite for OssWriter {
_ => Err(parse_error(resp).await?),
}
}
}

#[async_trait]
impl oio::MultipartUploadWrite for OssWriter {
async fn initiate_part(&self) -> Result<String> {
let resp = self
.core
Expand Down
Loading

0 comments on commit bf85c1f

Please sign in to comment.