diff --git a/Cargo.lock b/Cargo.lock index 242226a3f..a28c708e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -624,6 +624,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -1934,6 +1940,7 @@ dependencies = [ "aws-config", "aws-credential-types", "aws-sdk-s3", + "base64ct", "built", "bytes", "clap 3.2.25", diff --git a/mountpoint-s3-client/Cargo.toml b/mountpoint-s3-client/Cargo.toml index 8de524e8a..25bb0738c 100644 --- a/mountpoint-s3-client/Cargo.toml +++ b/mountpoint-s3-client/Cargo.toml @@ -38,6 +38,7 @@ anyhow = { version = "1.0.64", features = ["backtrace"] } aws-config = "0.54.1" aws-credential-types = "0.54.1" aws-sdk-s3 = "0.24.0" +base64ct = { version = "1.6.0", features = ["std"] } bytes = "1.2.1" clap = "3.2.12" ctor = "0.1.23" diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs index 9fe573f0c..8ef697c92 100644 --- a/mountpoint-s3-client/src/object_client.rs +++ b/mountpoint-s3-client/src/object_client.rs @@ -247,7 +247,23 @@ pub enum GetObjectAttributesError { /// TODO: Populate this struct with parameters from the S3 API, e.g., storage class, encryption. #[derive(Debug, Default, Clone)] #[non_exhaustive] -pub struct PutObjectParams {} +pub struct PutObjectParams { + /// Enable Crc32c trailing checksums. + pub trailing_checksums: bool, +} + +impl PutObjectParams { + /// Create a default [PutObjectParams]. + pub fn new() -> Self { + Self::default() + } + + /// Set Crc32c trailing checksums. + pub fn trailing_checksums(mut self, value: bool) -> Self { + self.trailing_checksums = value; + self + } +} /// A streaming put request which allows callers to asynchronously write /// the body of the request. diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs index b5a100f82..980022c8d 100644 --- a/mountpoint-s3-client/src/s3_crt_client.rs +++ b/mountpoint-s3-client/src/s3_crt_client.rs @@ -20,8 +20,8 @@ use mountpoint_s3_crt::io::event_loop::EventLoopGroup; use mountpoint_s3_crt::io::host_resolver::{HostResolver, HostResolverDefaultOptions}; use mountpoint_s3_crt::io::retry_strategy::{ExponentialBackoffJitterMode, RetryStrategy, StandardRetryOptions}; use mountpoint_s3_crt::s3::client::{ - init_default_signing_config, Client, ClientConfig, MetaRequestOptions, MetaRequestResult, MetaRequestType, - RequestType, + init_default_signing_config, ChecksumConfig, Client, ClientConfig, MetaRequestOptions, MetaRequestResult, + MetaRequestType, RequestType, }; use async_trait::async_trait; @@ -246,6 +246,7 @@ impl S3CrtClientInner { inner: message, uri, path_prefix, + checksum_config: None, }) } @@ -273,6 +274,9 @@ impl S3CrtClientInner { let first_body_part_clone = Arc::clone(&first_body_part); let mut options = MetaRequestOptions::new(); + if let Some(checksum_config) = message.checksum_config { + options.checksum_config(checksum_config); + } options .message(message.inner) .endpoint(message.uri) @@ -457,6 +461,7 @@ struct S3Message { inner: Message, uri: Uri, path_prefix: String, + checksum_config: Option, } impl S3Message { @@ -528,6 +533,11 @@ impl S3Message { fn set_body_stream(&mut self, input_stream: Option) -> Option { self.inner.set_body_stream(input_stream) } + + /// Sets the checksum configuration for this message. + fn set_checksum_config(&mut self, checksum_config: Option) { + self.checksum_config = checksum_config; + } } #[derive(Debug)] diff --git a/mountpoint-s3-client/src/s3_crt_client/put_object.rs b/mountpoint-s3-client/src/s3_crt_client/put_object.rs index bae12f7ee..951788cb9 100644 --- a/mountpoint-s3-client/src/s3_crt_client/put_object.rs +++ b/mountpoint-s3-client/src/s3_crt_client/put_object.rs @@ -5,7 +5,7 @@ use crate::{ObjectClientError, PutObjectRequest, PutObjectResult, S3CrtClient, S use async_trait::async_trait; use mountpoint_s3_crt::http::request_response::Header; use mountpoint_s3_crt::io::async_stream::{self, AsyncStreamWriter}; -use mountpoint_s3_crt::s3::client::MetaRequestType; +use mountpoint_s3_crt::s3::client::{ChecksumConfig, MetaRequestType}; use tracing::{debug, Span}; use super::{S3CrtClientInner, S3HttpRequest}; @@ -67,6 +67,11 @@ impl S3PutObjectRequest { .set_request_path(&key) .map_err(S3RequestError::construction_failure)?; + if self.params.trailing_checksums { + let checksum_config = ChecksumConfig::trailing_crc32c(); + message.set_checksum_config(Some(checksum_config)); + } + let (body_async_stream, writer) = async_stream::new_stream(&self.client.allocator); message.set_body_stream(Some(body_async_stream)); diff --git a/mountpoint-s3-client/tests/put_object.rs b/mountpoint-s3-client/tests/put_object.rs index 99a514b1d..56460f73b 100644 --- a/mountpoint-s3-client/tests/put_object.rs +++ b/mountpoint-s3-client/tests/put_object.rs @@ -2,12 +2,18 @@ pub mod common; +use base64ct::Base64; +use base64ct::Encoding; use common::*; use futures::{pin_mut, StreamExt}; use mountpoint_s3_client::GetObjectError; use mountpoint_s3_client::ObjectClient; use mountpoint_s3_client::ObjectClientResult; +use mountpoint_s3_client::PutObjectParams; use mountpoint_s3_client::PutObjectRequest; +use mountpoint_s3_client::S3ClientConfig; +use mountpoint_s3_client::S3CrtClient; +use mountpoint_s3_crt::checksums::crc32c; use rand::Rng; // Simple test for PUT object. Puts a single, small object as a single part and checks that the @@ -184,3 +190,48 @@ async fn test_put_object_abort() { let uploads_in_progress = get_mpu_count_for_key(&sdk_client, &bucket, &key).await.unwrap(); assert_eq!(uploads_in_progress, 0); } + +#[tokio::test] +async fn test_put_checksums() { + const PART_SIZE: usize = 5 * 1024 * 1024; + let (bucket, prefix) = get_test_bucket_and_prefix("test_put_checksums"); + let client_config = S3ClientConfig { + throughput_target_gbps: Some(10.0), + part_size: Some(PART_SIZE), + ..Default::default() + }; + let client = S3CrtClient::new(&get_test_region(), client_config).expect("could not create test client"); + let key = format!("{prefix}hello"); + + let mut rng = rand::thread_rng(); + let mut contents = vec![0u8; PART_SIZE * 2]; + rng.fill(&mut contents[..]); + + let params = PutObjectParams::new().trailing_checksums(true); + let mut request = client + .put_object(&bucket, &key, ¶ms) + .await + .expect("put_object should succeed"); + + request.write(&contents).await.unwrap(); + request.complete().await.unwrap(); + + let sdk_client = get_test_sdk_client().await; + let attributes = sdk_client + .get_object_attributes() + .bucket(bucket) + .key(key) + .object_attributes(aws_sdk_s3::model::ObjectAttributes::ObjectParts) + .send() + .await + .unwrap(); + let parts = attributes.object_parts().unwrap().parts().unwrap(); + let checksums: Vec<_> = parts.iter().map(|p| p.checksum_crc32_c().unwrap()).collect(); + let expected_checksums: Vec<_> = contents.chunks(PART_SIZE).map(crc32c::checksum).collect(); + + assert_eq!(checksums.len(), expected_checksums.len()); + for (checksum, expected_checksum) in checksums.into_iter().zip(expected_checksums.into_iter()) { + let encoded = Base64::encode_string(&expected_checksum.value().to_be_bytes()); + assert_eq!(checksum, encoded); + } +} diff --git a/mountpoint-s3-crt/src/checksums/crc32.rs b/mountpoint-s3-crt/src/checksums/crc32.rs index 55aff5074..119abf7d1 100644 --- a/mountpoint-s3-crt/src/checksums/crc32.rs +++ b/mountpoint-s3-crt/src/checksums/crc32.rs @@ -9,6 +9,11 @@ impl Crc32 { pub fn new(value: u32) -> Crc32 { Crc32(value) } + + /// The CRC32 checksum value. + pub fn value(&self) -> u32 { + self.0 + } } /// Computes the CRC32 checksum of a byte slice. diff --git a/mountpoint-s3-crt/src/checksums/crc32c.rs b/mountpoint-s3-crt/src/checksums/crc32c.rs index f0e17706d..f9630f625 100644 --- a/mountpoint-s3-crt/src/checksums/crc32c.rs +++ b/mountpoint-s3-crt/src/checksums/crc32c.rs @@ -9,6 +9,11 @@ impl Crc32c { pub fn new(value: u32) -> Crc32c { Crc32c(value) } + + /// The CRC32C checksum value. + pub fn value(&self) -> u32 { + self.0 + } } /// Computes the CRC32C checksum of a byte slice. diff --git a/mountpoint-s3-crt/src/s3/client.rs b/mountpoint-s3-crt/src/s3/client.rs index 95e7722b1..d0c5758d7 100644 --- a/mountpoint-s3-crt/src/s3/client.rs +++ b/mountpoint-s3-crt/src/s3/client.rs @@ -136,6 +136,9 @@ struct MetaRequestOptionsInner { /// Owned signing config, if provided. signing_config: Option, + /// Owned checksum config, if provided. + checksum_config: Option, + /// Telemetry callback, if provided on_telemetry: Option, @@ -210,6 +213,7 @@ impl MetaRequestOptions { message: None, endpoint: None, signing_config: None, + checksum_config: None, on_telemetry: None, on_headers: None, on_body: None, @@ -255,6 +259,16 @@ impl MetaRequestOptions { self } + /// Set the checksum config used for this message. + pub fn checksum_config(&mut self, checksum_config: ChecksumConfig) -> &mut Self { + // SAFETY: we aren't moving out of the struct. + let options = unsafe { Pin::get_unchecked_mut(Pin::as_mut(&mut self.0)) }; + options.checksum_config = Some(checksum_config); + options.inner.checksum_config = + options.checksum_config.as_mut().unwrap().to_inner_ptr() as *mut aws_s3_checksum_config; + self + } + /// Set the signing config used for this message. Not public because we copy it from the client /// when making a request. fn signing_config(&mut self, signing_config: SigningConfig) -> &mut Self { @@ -829,3 +843,28 @@ pub fn init_default_signing_config(region: &str, credentials_provider: Credentia SigningConfig(Arc::new(Box::into_pin(signing_config))) } + +/// The checksum configuration. +#[derive(Debug, Clone, Default)] +pub struct ChecksumConfig { + /// The struct we can pass into the CRT's functions. + inner: aws_s3_checksum_config, +} + +impl ChecksumConfig { + /// Create a [ChecksumConfig] enabling Crc32c trailing checksums in PUT requests. + pub fn trailing_crc32c() -> Self { + Self { + inner: aws_s3_checksum_config { + location: aws_s3_checksum_location::AWS_SCL_TRAILER, + checksum_algorithm: aws_s3_checksum_algorithm::AWS_SCA_CRC32C, + ..Default::default() + }, + } + } + + /// Get out the inner pointer to the checksum config + pub(crate) fn to_inner_ptr(&self) -> *const aws_s3_checksum_config { + &self.inner + } +} diff --git a/mountpoint-s3/src/upload.rs b/mountpoint-s3/src/upload.rs index 57c70565b..bff612bab 100644 --- a/mountpoint-s3/src/upload.rs +++ b/mountpoint-s3/src/upload.rs @@ -68,10 +68,8 @@ where bucket: &str, key: &str, ) -> ObjectClientResult { - let request = inner - .client - .put_object(bucket, key, &PutObjectParams::default()) - .await?; + let params = PutObjectParams::new().trailing_checksums(true); + let request = inner.client.put_object(bucket, key, ¶ms).await?; Ok(Self { bucket: bucket.to_owned(),