diff --git a/.github_changelog_generator b/.github_changelog_generator new file mode 100644 index 0000000..6d44a8e --- /dev/null +++ b/.github_changelog_generator @@ -0,0 +1,27 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Add special sections for documentation, security and performance +add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":["documentation"]},"security":{"prefix":"**Security updates:**","labels":["security"]},"performance":{"prefix":"**Performance improvements:**","labels":["performance"]}} +# so that the component is shown associated with the issue +issue-line-labels=object-store +# skip non object_store issues +exclude-labels=development-process,invalid,arrow,parquet,arrow-flight,parquet-derive,question +breaking_labels=api-change diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0788dae --- /dev/null +++ b/.gitignore @@ -0,0 +1,99 @@ +Cargo.lock +target +rusty-tags.vi +.history +.flatbuffers/ +.idea/ +.vscode +.devcontainer +venv/* +# created by doctests +parquet/data.parquet +# release notes cache +.githubchangeloggenerator.cache +.githubchangeloggenerator.cache.log +justfile +.prettierignore +.env +.editorconfig +# local azurite file +__azurite* +__blobstorage__ + +# .bak files +*.bak +*.bak2 +# OS-specific .gitignores + +# Mac .gitignore +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Linux .gitignore +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +# Windows .gitignore +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# Python virtual env in parquet crate +parquet/pytest/venv/ +__pycache__/ diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md new file mode 100644 index 0000000..f157e6f --- /dev/null +++ b/CHANGELOG-old.md @@ -0,0 +1,833 @@ + + +# Historical Changelog + +## 
[object_store_0.11.2](https://github.com/apache/arrow-rs/tree/object_store_0.11.2) (2024-12-20) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.1...object_store_0.11.2) + +**Implemented enhancements:** + +- object-store's AzureClient should protect against multiple streams performing put\_block in parallel for the same BLOB path [\#6868](https://github.com/apache/arrow-rs/issues/6868) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support S3 Put IfMatch [\#6799](https://github.com/apache/arrow-rs/issues/6799) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store Azure Government using OAuth [\#6759](https://github.com/apache/arrow-rs/issues/6759) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support for AWS Requester Pays buckets [\#6716](https://github.com/apache/arrow-rs/issues/6716) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object-store\]: Implement credential\_process support for S3 [\#6422](https://github.com/apache/arrow-rs/issues/6422) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Conditional put and rename\_if\_not\_exist on S3 [\#6285](https://github.com/apache/arrow-rs/issues/6285) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- `object_store` errors when `reqwest` `gzip` feature is enabled [\#6842](https://github.com/apache/arrow-rs/issues/6842) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Multi-part s3 uploads fail when using checksum [\#6793](https://github.com/apache/arrow-rs/issues/6793) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- `with_unsigned_payload` shouldn't generate payload hash [\#6697](https://github.com/apache/arrow-rs/issues/6697) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[Object\_store\] min\_ttl is too high for GKE tokens [\#6625](https://github.com/apache/arrow-rs/issues/6625) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store `test_private_bucket` fails - store: "S3", source: BucketNotFound { bucket: "bloxbender" } [\#6600](https://github.com/apache/arrow-rs/issues/6600) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- S3 endpoint and trailing slash result in weird/invalid requests [\#6580](https://github.com/apache/arrow-rs/issues/6580) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Use randomized content ID for Azure multipart uploads [\#6869](https://github.com/apache/arrow-rs/pull/6869) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([avarnon](https://github.com/avarnon)) +- Always explicitly disable `gzip` automatic decompression on reqwest client used by object\_store [\#6843](https://github.com/apache/arrow-rs/pull/6843) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([phillipleblanc](https://github.com/phillipleblanc)) +- object-store: remove S3ConditionalPut::ETagPutIfNotExists [\#6802](https://github.com/apache/arrow-rs/pull/6802) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([benesch](https://github.com/benesch)) +- Fix multipart uploads with checksums on object locked buckets [\#6794](https://github.com/apache/arrow-rs/pull/6794) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Add AuthorityHost to AzureConfigKey [\#6773](https://github.com/apache/arrow-rs/pull/6773) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([zadeluca](https://github.com/zadeluca)) +- object\_store: Add support for requester pays buckets [\#6768](https://github.com/apache/arrow-rs/pull/6768) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylebarron](https://github.com/kylebarron)) +- check sign\_payload instead of skip\_signature before computing checksum [\#6698](https://github.com/apache/arrow-rs/pull/6698) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mherrerarendon](https://github.com/mherrerarendon)) +- Update quick-xml requirement from 0.36.0 to 0.37.0 in /object\_store [\#6687](https://github.com/apache/arrow-rs/pull/6687) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Support native S3 conditional writes [\#6682](https://github.com/apache/arrow-rs/pull/6682) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([benesch](https://github.com/benesch)) +- \[object\_store\] fix S3 endpoint and trailing slash result in invalid requests [\#6641](https://github.com/apache/arrow-rs/pull/6641) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([adbmal](https://github.com/adbmal)) +- Lower GCP token min\_ttl to 4 minutes and add backoff to token refresh logic [\#6638](https://github.com/apache/arrow-rs/pull/6638) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mwylde](https://github.com/mwylde)) +- Remove `test_private_bucket` object\_store test [\#6601](https://github.com/apache/arrow-rs/pull/6601) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +## [object_store_0.11.1](https://github.com/apache/arrow-rs/tree/object_store_0.11.1) (2024-10-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.0...object_store_0.11.1) + +**Implemented enhancements:** + +- There is no way to pass object store client options as environment variables [\#6333](https://github.com/apache/arrow-rs/issues/6333) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Better Document Backoff Algorithm [\#6324](https://github.com/apache/arrow-rs/issues/6324) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add direction to `list_with_offset` [\#6274](https://github.com/apache/arrow-rs/issues/6274) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support server-side encryption with customer-provided keys \(SSE-C\) [\#6229](https://github.com/apache/arrow-rs/issues/6229) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- \[object-store\] Requested tokio version is too old - does not compile [\#6458](https://github.com/apache/arrow-rs/issues/6458) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Azure SAS tokens are visible when retry errors are logged via object\_store [\#6322](https://github.com/apache/arrow-rs/issues/6322) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- object\_store: fix typo in with\_connect\_timeout\_disabled that actually disabled non-connect timeouts 
[\#6563](https://github.com/apache/arrow-rs/pull/6563) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([adriangb](https://github.com/adriangb)) +- object\_store: Clarify what is a prefix in list\(\) documentation [\#6520](https://github.com/apache/arrow-rs/pull/6520) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([progval](https://github.com/progval)) +- object\_store: enable lint `unreachable_pub` [\#6512](https://github.com/apache/arrow-rs/pull/6512) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) +- \[object\_store\] Retry S3 requests with 200 response with "Error" in body [\#6508](https://github.com/apache/arrow-rs/pull/6508) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([PeterKeDer](https://github.com/PeterKeDer)) +- \[object-store\] Require tokio 1.29.0. [\#6459](https://github.com/apache/arrow-rs/pull/6459) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ashtuchkin](https://github.com/ashtuchkin)) +- feat: expose HTTP/2 max frame size in `object_store` [\#6442](https://github.com/apache/arrow-rs/pull/6442) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Derive `Clone` for `object_store::aws::AmazonS3` [\#6414](https://github.com/apache/arrow-rs/pull/6414) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ethe](https://github.com/ethe)) +- object\_store: Support Azure Fabric OAuth Provider [\#6382](https://github.com/apache/arrow-rs/pull/6382) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([RobinLin666](https://github.com/RobinLin666)) +- `object_store::GetOptions` derive `Clone` [\#6361](https://github.com/apache/arrow-rs/pull/6361) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([samuelcolvin](https://github.com/samuelcolvin)) +- \[object\_store\] Propagate env vars as object store client options [\#6334](https://github.com/apache/arrow-rs/pull/6334) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ccciudatu](https://github.com/ccciudatu)) +- docs\[object\_store\]: clarify the backoff strategy that is actually implemented [\#6325](https://github.com/apache/arrow-rs/pull/6325) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([westonpace](https://github.com/westonpace)) +- fix: azure sas token visible in logs [\#6323](https://github.com/apache/arrow-rs/pull/6323) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- object\_store/delimited: Fix `TrailingEscape` condition [\#6265](https://github.com/apache/arrow-rs/pull/6265) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- fix\(object\_store\): only add encryption headers for SSE-C in get request [\#6260](https://github.com/apache/arrow-rs/pull/6260) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) +- docs: Add parquet\_opendal in related projects [\#6236](https://github.com/apache/arrow-rs/pull/6236) ([Xuanwo](https://github.com/Xuanwo)) +- feat\(object\_store\): add support for server-side encryption with customer-provided keys \(SSE-C\) [\#6230](https://github.com/apache/arrow-rs/pull/6230) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jiachengdb](https://github.com/jiachengdb)) +- feat: further TLS options on ClientOptions: \#5034 [\#6148](https://github.com/apache/arrow-rs/pull/6148) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) + + + +## [object_store_0.11.0](https://github.com/apache/arrow-rs/tree/object_store_0.11.0) (2024-08-12) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.10.2...object_store_0.11.0) + +**Breaking changes:** + +- Make object\_store errors non-exhaustive [\#6165](https://github.com/apache/arrow-rs/pull/6165) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update snafu to `0.8.0` in object\_store \(\#5930\) [\#6070](https://github.com/apache/arrow-rs/pull/6070) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + + +**Merged pull requests:** + +- Add LICENSE and NOTICE files to object_store [\#6234](https://github.com/apache/arrow-rs/pull/6234) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- feat\(object\_store\): add `PermissionDenied` variant to top-level error [\#6194](https://github.com/apache/arrow-rs/pull/6194) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kyle-mccarthy](https://github.com/kyle-mccarthy)) +- Update object store MSRV to `1.64` [\#6123](https://github.com/apache/arrow-rs/pull/6123) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Fix clippy in object\_store crate [\#6120](https://github.com/apache/arrow-rs/pull/6120) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +## [object_store_0.10.2](https://github.com/apache/arrow-rs/tree/object_store_0.10.2) (2024-07-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.10.1...object_store_0.10.2) + +**Implemented enhancements:** + +- Relax `WriteMultipart` API to support aborting after completion [\#5977](https://github.com/apache/arrow-rs/issues/5977) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Make ObjectStoreScheme in the object\_store crate public [\#5911](https://github.com/apache/arrow-rs/issues/5911) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add BufUploader to implement same feature upon `WriteMultipart` like `BufWriter` [\#5834](https://github.com/apache/arrow-rs/issues/5834) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- Investigate why `InstanceCredentialProvider::cache` is flagged as dead code [\#5884](https://github.com/apache/arrow-rs/issues/5884) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object\_store\] Potential race condition in `list_with_delimiter` on `Local` [\#5800](https://github.com/apache/arrow-rs/issues/5800) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Documentation updates:** + +- Correct timeout in comment from 5s to 30s [\#6073](https://github.com/apache/arrow-rs/pull/6073) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([trungda](https://github.com/trungda)) +- docs: Fix broken links of object\_store\_opendal README 
[\#5929](https://github.com/apache/arrow-rs/pull/5929) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- docs: Add object\_store\_opendal as related projects [\#5926](https://github.com/apache/arrow-rs/pull/5926) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- chore: update docs to delineate which ObjectStore lists are recursive [\#5794](https://github.com/apache/arrow-rs/pull/5794) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wiedld](https://github.com/wiedld)) +- Document object store release cadence [\#5750](https://github.com/apache/arrow-rs/pull/5750) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Sanitize error message for sensitive requests [\#6074](https://github.com/apache/arrow-rs/pull/6074) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update quick-xml requirement from 0.35.0 to 0.36.0 in /object\_store [\#6032](https://github.com/apache/arrow-rs/pull/6032) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- use GCE metadata server env var overrides [\#6015](https://github.com/apache/arrow-rs/pull/6015) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([barronw](https://github.com/barronw)) +- Update quick-xml requirement from 0.34.0 to 0.35.0 in /object\_store [\#5983](https://github.com/apache/arrow-rs/pull/5983) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Automatically cleanup empty dirs in LocalFileSystem [\#5978](https://github.com/apache/arrow-rs/pull/5978) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([fsdvh](https://github.com/fsdvh)) +- WriteMultipart Abort on MultipartUpload::complete Error [\#5974](https://github.com/apache/arrow-rs/pull/5974) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([fsdvh](https://github.com/fsdvh)) +- Update quick-xml requirement from 0.33.0 to 0.34.0 in /object\_store [\#5954](https://github.com/apache/arrow-rs/pull/5954) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update quick-xml requirement from 0.32.0 to 0.33.0 in /object\_store [\#5946](https://github.com/apache/arrow-rs/pull/5946) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add `MultipartUpload` blanket implementation for `Box` [\#5919](https://github.com/apache/arrow-rs/pull/5919) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([fsdvh](https://github.com/fsdvh)) +- Add user defined metadata [\#5915](https://github.com/apache/arrow-rs/pull/5915) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([criccomini](https://github.com/criccomini)) +- Make ObjectStoreScheme public [\#5912](https://github.com/apache/arrow-rs/pull/5912) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([orf](https://github.com/orf)) +- chore: Remove not used cache in InstanceCredentialProvider [\#5888](https://github.com/apache/arrow-rs/pull/5888) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- Fix clippy for object\_store [\#5883](https://github.com/apache/arrow-rs/pull/5883) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Update quick-xml requirement from 0.31.0 to 0.32.0 in /object\_store [\#5870](https://github.com/apache/arrow-rs/pull/5870) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat\(object\_store\): Add `put` API for buffered::BufWriter [\#5835](https://github.com/apache/arrow-rs/pull/5835) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- Fix 5592: Colon \(:\) in object\_store::path::{Path} is not handled on Windows [\#5830](https://github.com/apache/arrow-rs/pull/5830) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([hesampakdaman](https://github.com/hesampakdaman)) +- Fix issue \#5800: Handle missing files in list\_with\_delimiter [\#5803](https://github.com/apache/arrow-rs/pull/5803) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([hesampakdaman](https://github.com/hesampakdaman)) +- Update nix requirement from 0.28.0 to 0.29.0 in /object\_store [\#5799](https://github.com/apache/arrow-rs/pull/5799) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update itertools requirement from 0.12.0 to 0.13.0 in /object\_store [\#5780](https://github.com/apache/arrow-rs/pull/5780) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add additional WriteMultipart tests \(\#5743\) [\#5746](https://github.com/apache/arrow-rs/pull/5746) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + + + +\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* + +## [object_store_0.10.1](https://github.com/apache/arrow-rs/tree/object_store_0.10.1) (2024-05-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.10.0...object_store_0.10.1) + +**Implemented enhancements:** + +- Allow specifying PUT options when using `BufWriter` [\#5692](https://github.com/apache/arrow-rs/issues/5692) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add more attributes to `object_store::Attribute` [\#5689](https://github.com/apache/arrow-rs/issues/5689) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- feat object\_store: moving tests from src/ to a tests/ folder and enabling access to test functions for enabling a shared integration test suite [\#5685](https://github.com/apache/arrow-rs/issues/5685) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Release Object Store 0.10.0 [\#5647](https://github.com/apache/arrow-rs/issues/5647) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- Using WriteMultipart::put results in 0 bytes being written [\#5743](https://github.com/apache/arrow-rs/issues/5743) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Fix PutPayloadMut::push not updating content\_length \(\#5743\) 
[\#5744](https://github.com/apache/arrow-rs/pull/5744) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Export object\_store integration tests [\#5709](https://github.com/apache/arrow-rs/pull/5709) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add `BufWriter::with_attributes` and `::with_tags` in `object_store` [\#5693](https://github.com/apache/arrow-rs/pull/5693) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([netthier](https://github.com/netthier)) +- Add more attributes to `object_store::Attribute` [\#5690](https://github.com/apache/arrow-rs/pull/5690) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([netthier](https://github.com/netthier)) + + +## [object_store_0.10.0](https://github.com/apache/arrow-rs/tree/object_store_0.10.0) (2024-04-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.9.1...object_store_0.10.0) + +**Breaking changes:** + +- Add put\_multipart\_opts \(\#5435\) [\#5652](https://github.com/apache/arrow-rs/pull/5652) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add Attributes API \(\#5329\) [\#5650](https://github.com/apache/arrow-rs/pull/5650) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Support non-contiguous put payloads / vectored writes \(\#5514\) [\#5538](https://github.com/apache/arrow-rs/pull/5538) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Replace AsyncWrite with Upload trait and rename MultiPartStore to MultipartStore \(\#5458\) [\#5500](https://github.com/apache/arrow-rs/pull/5500) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Improve Retry Coverage [\#5608](https://github.com/apache/arrow-rs/issues/5608) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Zero Copy Support [\#5593](https://github.com/apache/arrow-rs/issues/5593) +- ObjectStore bulk delete [\#5591](https://github.com/apache/arrow-rs/issues/5591) +- Retry on Broken Connection [\#5589](https://github.com/apache/arrow-rs/issues/5589) +- Inconsistent Multipart Nomenclature [\#5526](https://github.com/apache/arrow-rs/issues/5526) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[ObjectStore\] Non-Contiguous Write Payloads [\#5514](https://github.com/apache/arrow-rs/issues/5514) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- In Object Store, return version & etag on multipart put. 
[\#5443](https://github.com/apache/arrow-rs/issues/5443) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Release Object Store 0.9.1 [\#5436](https://github.com/apache/arrow-rs/issues/5436) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: allow setting content-type per request [\#5329](https://github.com/apache/arrow-rs/issues/5329) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- GCS Signed URL Support [\#5233](https://github.com/apache/arrow-rs/issues/5233) + +**Fixed bugs:** + +- \[object\_store\] minor bug: typos present in local variable [\#5628](https://github.com/apache/arrow-rs/issues/5628) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[arrow-csv\] Schema inference requires csv on disk [\#5551](https://github.com/apache/arrow-rs/issues/5551) +- Local object store copy/rename with nonexistent `from` file loops forever instead of erroring [\#5503](https://github.com/apache/arrow-rs/issues/5503) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object store ApplicationDefaultCredentials auth is not working on windows [\#5466](https://github.com/apache/arrow-rs/issues/5466) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- MicrosoftAzure store list result omits empty objects [\#5451](https://github.com/apache/arrow-rs/issues/5451) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Documentation updates:** + +- Minor: add additional documentation about `BufWriter` [\#5519](https://github.com/apache/arrow-rs/pull/5519) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- minor-fix: removed typos in object\_store sub crate [\#5629](https://github.com/apache/arrow-rs/pull/5629) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Silemo](https://github.com/Silemo)) +- Retry on More Error Classes [\#5609](https://github.com/apache/arrow-rs/pull/5609) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([andrebsguedes](https://github.com/andrebsguedes)) +- Fix handling of empty multipart uploads for GCS [\#5590](https://github.com/apache/arrow-rs/pull/5590) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Upgrade object\_store dependency to use chrono `0.4.34` [\#5578](https://github.com/apache/arrow-rs/pull/5578) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([l1nxy](https://github.com/l1nxy)) +- Fix Latest Clippy Lints for object\_store [\#5546](https://github.com/apache/arrow-rs/pull/5546) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update reqwest 0.12 and http 1.0 [\#5536](https://github.com/apache/arrow-rs/pull/5536) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Implement MultipartStore for ThrottledStore [\#5533](https://github.com/apache/arrow-rs/pull/5533) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- fix: copy/rename return error if source is nonexistent [\#5528](https://github.com/apache/arrow-rs/pull/5528) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dimbtp](https://github.com/dimbtp)) +- Prepare arrow 51.0.0 
[\#5516](https://github.com/apache/arrow-rs/pull/5516) ([tustvold](https://github.com/tustvold)) +- Implement MultiPartStore for InMemory [\#5495](https://github.com/apache/arrow-rs/pull/5495) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add more comprehensive documentation on testing and benchmarking to CONTRIBUTING.md [\#5478](https://github.com/apache/arrow-rs/pull/5478) ([monkwire](https://github.com/monkwire)) +- add support for gcp application default auth on windows in object store [\#5473](https://github.com/apache/arrow-rs/pull/5473) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Itayazolay](https://github.com/Itayazolay)) +- Update base64 requirement from 0.21 to 0.22 in /object\_store [\#5465](https://github.com/apache/arrow-rs/pull/5465) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Uses ResourceType for filtering list directories instead of workaround [\#5452](https://github.com/apache/arrow-rs/pull/5452) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([andrebsguedes](https://github.com/andrebsguedes)) +- Add GCS signed URL support [\#5300](https://github.com/apache/arrow-rs/pull/5300) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([l1nxy](https://github.com/l1nxy)) + +## [object_store_0.9.1](https://github.com/apache/arrow-rs/tree/object_store_0.9.1) (2024-03-01) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.9.0...object_store_0.9.1) + +**Implemented enhancements:** + +- \[object\_store\] Enable anonymous read access for Azure [\#5424](https://github.com/apache/arrow-rs/issues/5424) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support for additional URL formats in object\_store for Azure blob [\#5370](https://github.com/apache/arrow-rs/issues/5370) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Mention "Http" support in README [\#5320](https://github.com/apache/arrow-rs/issues/5320) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Pass Options to HttpBuilder in parse\_url\_opts [\#5310](https://github.com/apache/arrow-rs/issues/5310) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Remove Localstack DynamoDB Workaround Once Fixed Upstream [\#5267](https://github.com/apache/arrow-rs/issues/5267) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Can I use S3 server side encryption [\#5087](https://github.com/apache/arrow-rs/issues/5087) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- delete\_stream fails in MinIO [\#5414](https://github.com/apache/arrow-rs/issues/5414) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object\_store\] Completing an empty Multipart Upload fails for AWS S3 [\#5404](https://github.com/apache/arrow-rs/issues/5404) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Multipart upload can leave futures unpolled, leading to timeout [\#5366](https://github.com/apache/arrow-rs/issues/5366) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Broken Link in README \(Rust Object Store\) Content [\#5309](https://github.com/apache/arrow-rs/issues/5309) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + 
+**Merged pull requests:** + +- Expose path\_to\_filesystem public [\#5441](https://github.com/apache/arrow-rs/pull/5441) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([metesynnada](https://github.com/metesynnada)) +- Update nix requirement from 0.27.1 to 0.28.0 in /object\_store [\#5432](https://github.com/apache/arrow-rs/pull/5432) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add BufWriter for Adaptive Put / Multipart Upload [\#5431](https://github.com/apache/arrow-rs/pull/5431) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Enable anonymous access for MicrosoftAzure [\#5425](https://github.com/apache/arrow-rs/pull/5425) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([andrebsguedes](https://github.com/andrebsguedes)) +- fix\(object\_store\): Include Content-MD5 header for S3 DeleteObjects [\#5415](https://github.com/apache/arrow-rs/pull/5415) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([paraseba](https://github.com/paraseba)) +- docs\(object\_store\): Mention HTTP/WebDAV in README [\#5409](https://github.com/apache/arrow-rs/pull/5409) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- \[object\_store\] Fix empty Multipart Upload for AWS S3 [\#5405](https://github.com/apache/arrow-rs/pull/5405) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([andrebsguedes](https://github.com/andrebsguedes)) +- feat: S3 server-side encryption [\#5402](https://github.com/apache/arrow-rs/pull/5402) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- Pull container name from URL for Azure blob [\#5371](https://github.com/apache/arrow-rs/pull/5371) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([bradvoth](https://github.com/bradvoth)) +- docs\(object-store\): add warning to flush [\#5369](https://github.com/apache/arrow-rs/pull/5369) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- Minor\(docs\): update master to main for DataFusion/Ballista [\#5363](https://github.com/apache/arrow-rs/pull/5363) ([caicancai](https://github.com/caicancai)) +- Test parse\_url\_opts for HTTP \(\#5310\) [\#5316](https://github.com/apache/arrow-rs/pull/5316) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update IOx links [\#5312](https://github.com/apache/arrow-rs/pull/5312) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Pass options to HTTPBuilder in parse\_url\_opts \(\#5310\) [\#5311](https://github.com/apache/arrow-rs/pull/5311) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Bump actions/cache from 3 to 4 [\#5308](https://github.com/apache/arrow-rs/pull/5308) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Remove localstack DynamoDB workaround \(\#5267\) [\#5307](https://github.com/apache/arrow-rs/pull/5307) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- refactor: log server error during object store retries 
[\#5294](https://github.com/apache/arrow-rs/pull/5294) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Prepare arrow 50.0.0 [\#5291](https://github.com/apache/arrow-rs/pull/5291) ([tustvold](https://github.com/tustvold)) +- Enable JS tests again [\#5287](https://github.com/apache/arrow-rs/pull/5287) ([domoritz](https://github.com/domoritz)) + +## [object_store_0.9.0](https://github.com/apache/arrow-rs/tree/object_store_0.9.0) (2024-01-05) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.8.0...object_store_0.9.0) + +**Breaking changes:** + +- Remove deprecated try\_with\_option methods [\#5237](https://github.com/apache/arrow-rs/pull/5237) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: full HTTP range support [\#5222](https://github.com/apache/arrow-rs/pull/5222) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([clbarnes](https://github.com/clbarnes)) +- feat\(object\_store\): use http1 by default [\#5204](https://github.com/apache/arrow-rs/pull/5204) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- refactor: change `object_store` CA handling [\#5056](https://github.com/apache/arrow-rs/pull/5056) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) + +**Implemented enhancements:** + +- Azure Signed URL Support [\#5232](https://github.com/apache/arrow-rs/issues/5232) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object-store\] Make aws region optional. [\#5211](https://github.com/apache/arrow-rs/issues/5211) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object\_store,gcp\] Document GoogleCloudStorage Default Credentials [\#5187](https://github.com/apache/arrow-rs/issues/5187) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support S3 Express One Zone [\#5140](https://github.com/apache/arrow-rs/issues/5140) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- `object_store`: Allow 403 Forbidden for `copy_if_not_exists` S3 status code [\#5132](https://github.com/apache/arrow-rs/issues/5132) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add `copy_if_not_exists` support for AmazonS3 via DynamoDB Lock Support [\#4880](https://github.com/apache/arrow-rs/issues/4880) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: native certs, w/o webpki-roots [\#4870](https://github.com/apache/arrow-rs/issues/4870) +- object\_store: range request with suffix [\#4611](https://github.com/apache/arrow-rs/issues/4611) + +**Fixed bugs:** + +- ObjectStore::get\_opts Incorrectly Returns Response Size not Object Size [\#5272](https://github.com/apache/arrow-rs/issues/5272) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Single object store has limited throughput on GCS [\#5194](https://github.com/apache/arrow-rs/issues/5194) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- local::tests::invalid\_path fails during object store release verification [\#5035](https://github.com/apache/arrow-rs/issues/5035) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Object Store Doctest Failure with Default Features 
[\#5025](https://github.com/apache/arrow-rs/issues/5025) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Documentation updates:** + +- Document default value of InstanceCredentialProvider [\#5188](https://github.com/apache/arrow-rs/pull/5188) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([justinabrahms](https://github.com/justinabrahms)) + +**Merged pull requests:** + +- Retry Safe/Read-Only Requests on Timeout [\#5278](https://github.com/apache/arrow-rs/pull/5278) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix ObjectMeta::size for range requests \(\#5272\) [\#5276](https://github.com/apache/arrow-rs/pull/5276) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- docs\(object\_store\): Mention `with_allow_http` in docs of `with_endpoint` [\#5275](https://github.com/apache/arrow-rs/pull/5275) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Xuanwo](https://github.com/Xuanwo)) +- Support S3 Express One Zone [\#5268](https://github.com/apache/arrow-rs/pull/5268) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- feat\(object\_store\): Azure url signing [\#5259](https://github.com/apache/arrow-rs/pull/5259) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- DynamoDB ConditionalPut [\#5247](https://github.com/apache/arrow-rs/pull/5247) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Default AWS region to us-east-1 \(\#5211\) [\#5244](https://github.com/apache/arrow-rs/pull/5244) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- ci: Fail Miri CI on first failure [\#5243](https://github.com/apache/arrow-rs/pull/5243) ([Jefffrey](https://github.com/Jefffrey)) +- Bump actions/upload-pages-artifact from 2 to 3 [\#5229](https://github.com/apache/arrow-rs/pull/5229) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/setup-python from 4 to 5 [\#5175](https://github.com/apache/arrow-rs/pull/5175) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: ensure take\_fixed\_size\_list can handle null indices [\#5170](https://github.com/apache/arrow-rs/pull/5170) ([westonpace](https://github.com/westonpace)) +- Bump actions/labeler from 4.3.0 to 5.0.0 [\#5167](https://github.com/apache/arrow-rs/pull/5167) ([dependabot[bot]](https://github.com/apps/dependabot)) +- object\_store: fix failing doctest with default features [\#5161](https://github.com/apache/arrow-rs/pull/5161) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Jefffrey](https://github.com/Jefffrey)) +- Update rustls-pemfile requirement from 1.0 to 2.0 in /object\_store [\#5155](https://github.com/apache/arrow-rs/pull/5155) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow 403 for overwrite prevention [\#5134](https://github.com/apache/arrow-rs/pull/5134) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([emcake](https://github.com/emcake)) +- Fix ObjectStore.LocalFileSystem.put\_opts for blobfuse [\#5094](https://github.com/apache/arrow-rs/pull/5094) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([RobinLin666](https://github.com/RobinLin666)) +- Update itertools requirement from 0.11.0 to 0.12.0 in /object\_store [\#5077](https://github.com/apache/arrow-rs/pull/5077) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add a PR under "Breaking changes" in the object\_store 0.8.0 changelog [\#5063](https://github.com/apache/arrow-rs/pull/5063) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([carols10cents](https://github.com/carols10cents)) +- Prepare arrow 49.0.0 [\#5054](https://github.com/apache/arrow-rs/pull/5054) ([tustvold](https://github.com/tustvold)) +- Fix invalid\_path test [\#5026](https://github.com/apache/arrow-rs/pull/5026) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Implement `copy_if_not_exist` for `AmazonS3` using DynamoDB \(\#4880\) [\#4918](https://github.com/apache/arrow-rs/pull/4918) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +## [object_store_0.8.0](https://github.com/apache/arrow-rs/tree/object_store_0.8.0) (2023-11-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.7.1...object_store_0.8.0) + +**Breaking changes:** + +- Remove ObjectStore::append [\#5016](https://github.com/apache/arrow-rs/pull/5016) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Don't panic on invalid Azure access key \(\#4972\) [\#4974](https://github.com/apache/arrow-rs/pull/4974) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Return `PutResult` with an ETag from ObjectStore::put \(\#4934\) [\#4944](https://github.com/apache/arrow-rs/pull/4944) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ObjectMeta::version and GetOptions::version \(\#4925\) [\#4935](https://github.com/apache/arrow-rs/pull/4935) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add GetOptions::head [\#4931](https://github.com/apache/arrow-rs/pull/4931) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Remove Nested async and Fallibility from ObjectStore::list [\#4930](https://github.com/apache/arrow-rs/pull/4930) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ObjectStore::put\_opts / Conditional Put \(\#4879\) [\#4984](https://github.com/apache/arrow-rs/pull/4984) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Relax Path Safety on Parse [\#5019](https://github.com/apache/arrow-rs/issues/5019) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- ObjectStore: hard to determine the cause of the error thrown from retry [\#5013](https://github.com/apache/arrow-rs/issues/5013) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- continue existing multi-part upload [\#4961](https://github.com/apache/arrow-rs/issues/4961) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Simplify 
ObjectStore::List [\#4946](https://github.com/apache/arrow-rs/issues/4946) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Return ETag and Version on Put [\#4934](https://github.com/apache/arrow-rs/issues/4934) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support Not Signing Requests in AmazonS3 [\#4927](https://github.com/apache/arrow-rs/issues/4927) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Get Object By Version [\#4925](https://github.com/apache/arrow-rs/issues/4925) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Plans for supporting Extension Array to support Fixed shape tensor Array [\#4890](https://github.com/apache/arrow-rs/issues/4890) +- Conditional Put Support [\#4879](https://github.com/apache/arrow-rs/issues/4879) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- creates\_dir\_if\_not\_present\_append Test is Flaky [\#4872](https://github.com/apache/arrow-rs/issues/4872) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Release object\_store `0.7.1` [\#4858](https://github.com/apache/arrow-rs/issues/4858) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support User-Defined Object Metadata [\#4754](https://github.com/apache/arrow-rs/issues/4754) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- APIs for directly managing multi-part uploads and saving potential parquet footers [\#4608](https://github.com/apache/arrow-rs/issues/4608) + +**Fixed bugs:** + +- ObjectStore parse\_url Incorrectly Handles URLs with Spaces [\#5017](https://github.com/apache/arrow-rs/issues/5017) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[objects-store\]: periods/dots error in GCP bucket [\#4991](https://github.com/apache/arrow-rs/issues/4991) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Azure ImdsManagedIdentityProvider does not work in Azure functions [\#4976](https://github.com/apache/arrow-rs/issues/4976) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Panic when using an azure object store with an invalid access key [\#4972](https://github.com/apache/arrow-rs/issues/4972) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Handle Body Errors in AWS CompleteMultipartUpload [\#4965](https://github.com/apache/arrow-rs/issues/4965) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- ObjectStore multiple\_append Test is Flaky [\#4868](https://github.com/apache/arrow-rs/issues/4868) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[objectstore\] Problem with special characters in file path [\#4454](https://github.com/apache/arrow-rs/issues/4454) + +**Closed issues:** + +- Include onelake fabric path for https [\#5000](https://github.com/apache/arrow-rs/issues/5000) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object\_store\] Support generating and using signed upload URLs [\#4763](https://github.com/apache/arrow-rs/issues/4763) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Relax path safety \(\#5019\) [\#5020](https://github.com/apache/arrow-rs/pull/5020) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Decode URL paths \(\#5017\) 
[\#5018](https://github.com/apache/arrow-rs/pull/5018) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- ObjectStore: make error msg thrown from retry more detailed [\#5012](https://github.com/apache/arrow-rs/pull/5012) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Rachelint](https://github.com/Rachelint)) +- Support onelake fabric paths in parse\_url \(\#5000\) [\#5002](https://github.com/apache/arrow-rs/pull/5002) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Object tagging \(\#4754\) [\#4999](https://github.com/apache/arrow-rs/pull/4999) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- \[MINOR\] No need to jump to web pages [\#4994](https://github.com/apache/arrow-rs/pull/4994) ([smallzhongfeng](https://github.com/smallzhongfeng)) +- Pushdown list\_with\_offset for GCS [\#4993](https://github.com/apache/arrow-rs/pull/4993) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Support bucket name with `.` when parsing GCS URL \(\#4991\) [\#4992](https://github.com/apache/arrow-rs/pull/4992) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Increase default timeout to 30 seconds [\#4989](https://github.com/apache/arrow-rs/pull/4989) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Conditional Put \(\#4879\) [\#4984](https://github.com/apache/arrow-rs/pull/4984) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update quick-xml requirement from 0.30.0 to 0.31.0 in /object\_store [\#4983](https://github.com/apache/arrow-rs/pull/4983) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/setup-node from 3 to 4 [\#4982](https://github.com/apache/arrow-rs/pull/4982) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support ImdsManagedIdentityProvider in Azure Functions \(\#4976\) [\#4977](https://github.com/apache/arrow-rs/pull/4977) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add MultiPartStore \(\#4961\) \(\#4608\) [\#4971](https://github.com/apache/arrow-rs/pull/4971) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Split gcp Module [\#4956](https://github.com/apache/arrow-rs/pull/4956) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add module links in docs root [\#4955](https://github.com/apache/arrow-rs/pull/4955) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Prepare arrow 48.0.0 [\#4948](https://github.com/apache/arrow-rs/pull/4948) ([tustvold](https://github.com/tustvold)) +- Allow opting out of request signing \(\#4927\) [\#4929](https://github.com/apache/arrow-rs/pull/4929) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Default connection and request timeouts of 5 seconds [\#4928](https://github.com/apache/arrow-rs/pull/4928) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Support service\_account in ApplicationDefaultCredentials and Use SelfSignedJwt [\#4926](https://github.com/apache/arrow-rs/pull/4926) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Generate `ETag`s for `InMemory` and `LocalFileSystem` \(\#4879\) [\#4922](https://github.com/apache/arrow-rs/pull/4922) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Cleanup `object_store::retry` client error handling [\#4915](https://github.com/apache/arrow-rs/pull/4915) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix integration tests [\#4889](https://github.com/apache/arrow-rs/pull/4889) ([tustvold](https://github.com/tustvold)) +- Support Parsing Avro File Headers [\#4888](https://github.com/apache/arrow-rs/pull/4888) ([tustvold](https://github.com/tustvold)) +- Update ring requirement from 0.16 to 0.17 in /object\_store [\#4887](https://github.com/apache/arrow-rs/pull/4887) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add AWS presigned URL support [\#4876](https://github.com/apache/arrow-rs/pull/4876) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([carols10cents](https://github.com/carols10cents)) +- Flush in creates\_dir\_if\_not\_present\_append \(\#4872\) [\#4874](https://github.com/apache/arrow-rs/pull/4874) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Flush in multiple\_append test \(\#4868\) [\#4869](https://github.com/apache/arrow-rs/pull/4869) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Enable new integration tests \(\#4828\) [\#4862](https://github.com/apache/arrow-rs/pull/4862) ([tustvold](https://github.com/tustvold)) + +## [object_store_0.7.1](https://github.com/apache/arrow-rs/tree/object_store_0.7.1) (2023-09-26) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.7.0...object_store_0.7.1) + +**Implemented enhancements:** + +- Automatically Cleanup LocalFileSystem Temporary Files [\#4778](https://github.com/apache/arrow-rs/issues/4778) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store: Expose an async reader API for object store [\#4762](https://github.com/apache/arrow-rs/issues/4762) +- Improve proxy support by using reqwest::Proxy as configuration [\#4713](https://github.com/apache/arrow-rs/issues/4713) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- object-store: http shouldn't perform range requests unless `accept-ranges: bytes` header is present [\#4839](https://github.com/apache/arrow-rs/issues/4839) +- object-store: http-store fails when url doesn't have last-modified header on 0.7.0 [\#4831](https://github.com/apache/arrow-rs/issues/4831) +- object-store fails to compile for `wasm32-unknown-unknown` with `http` feature [\#4776](https://github.com/apache/arrow-rs/issues/4776) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store: could not find `header` in `client` for `http` feature [\#4775](https://github.com/apache/arrow-rs/issues/4775) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- LocalFileSystem Copy and Rename Don't Create Intermediate Directories [\#4760](https://github.com/apache/arrow-rs/issues/4760) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- LocalFileSystem Copy is not Atomic [\#4758](https://github.com/apache/arrow-rs/issues/4758) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Closed issues:** + +- object\_store Azure Government Cloud functionality? [\#4853](https://github.com/apache/arrow-rs/issues/4853) + +**Merged pull requests:** + +- Add ObjectStore BufReader \(\#4762\) [\#4857](https://github.com/apache/arrow-rs/pull/4857) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Allow overriding azure endpoint [\#4854](https://github.com/apache/arrow-rs/pull/4854) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve object\_store docs.rs landing page [\#4849](https://github.com/apache/arrow-rs/pull/4849) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Error if Remote Ignores HTTP Range Header [\#4841](https://github.com/apache/arrow-rs/pull/4841) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([universalmind303](https://github.com/universalmind303)) +- Perform HEAD request for HttpStore::head [\#4837](https://github.com/apache/arrow-rs/pull/4837) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- fix: object store http header last modified [\#4834](https://github.com/apache/arrow-rs/pull/4834) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([universalmind303](https://github.com/universalmind303)) +- Prepare arrow 47.0.0 [\#4827](https://github.com/apache/arrow-rs/pull/4827) ([tustvold](https://github.com/tustvold)) +- ObjectStore Wasm32 Fixes \(\#4775\) \(\#4776\) [\#4796](https://github.com/apache/arrow-rs/pull/4796) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Best effort cleanup of staged upload files \(\#4778\) [\#4792](https://github.com/apache/arrow-rs/pull/4792) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Relaxing type bounds on coalesce\_ranges and collect\_bytes [\#4787](https://github.com/apache/arrow-rs/pull/4787) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([sumerman](https://github.com/sumerman)) +- Update object\_store chrono deprecations [\#4786](https://github.com/apache/arrow-rs/pull/4786) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make coalesce\_ranges and collect\_bytes available for crate users [\#4784](https://github.com/apache/arrow-rs/pull/4784) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([sumerman](https://github.com/sumerman)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Make ObjectStore::copy Atomic and Automatically Create Parent Directories \(\#4758\) \(\#4760\) [\#4759](https://github.com/apache/arrow-rs/pull/4759) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([tustvold](https://github.com/tustvold)) +- Update nix requirement from 0.26.1 to 0.27.1 in /object\_store [\#4744](https://github.com/apache/arrow-rs/pull/4744) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) +- Add `with_proxy_ca_certificate` and `with_proxy_excludes` [\#4714](https://github.com/apache/arrow-rs/pull/4714) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([gordonwang0](https://github.com/gordonwang0)) +- Update object\_store Dependencies and Configure Dependabot [\#4700](https://github.com/apache/arrow-rs/pull/4700) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +## [object_store_0.7.0](https://github.com/apache/arrow-rs/tree/object_store_0.7.0) (2023-08-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.6.1...object_store_0.7.0) + +**Breaking changes:** + +- Add range and ObjectMeta to GetResult \(\#4352\) \(\#4495\) [\#4677](https://github.com/apache/arrow-rs/pull/4677) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add AzureConfigKey::ContainerName [\#4629](https://github.com/apache/arrow-rs/issues/4629) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: multipart ranges for HTTP [\#4612](https://github.com/apache/arrow-rs/issues/4612) +- Make object\_store::multipart public [\#4569](https://github.com/apache/arrow-rs/issues/4569) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Export `ClientConfigKey` and make the `HttpBuilder` more consistent with other builders [\#4515](https://github.com/apache/arrow-rs/issues/4515) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store/InMemory: Make `clone()` non-async [\#4496](https://github.com/apache/arrow-rs/issues/4496) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add Range to GetResult::File [\#4352](https://github.com/apache/arrow-rs/issues/4352) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support copy\_if\_not\_exists for Cloudflare R2 \(S3 API\) [\#4190](https://github.com/apache/arrow-rs/issues/4190) + +**Fixed bugs:** + +- object\_store documentation is broken [\#4683](https://github.com/apache/arrow-rs/issues/4683) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Exports are not sufficient for configuring some object stores, for example minio running locally [\#4530](https://github.com/apache/arrow-rs/issues/4530) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Uploading empty file to S3 results in "411 Length Required" [\#4514](https://github.com/apache/arrow-rs/issues/4514) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- GCP doesn't fetch public objects [\#4417](https://github.com/apache/arrow-rs/issues/4417) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Closed issues:** + +- \[object\_store\] when Create a AmazonS3 instance work with MinIO without set endpoint got error MissingRegion [\#4617](https://github.com/apache/arrow-rs/issues/4617) +- AWS Profile credentials no longer working in object\_store 0.6.1 [\#4556](https://github.com/apache/arrow-rs/issues/4556) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Add AzureConfigKey::ContainerName \(\#4629\) [\#4686](https://github.com/apache/arrow-rs/pull/4686) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix MSRV CI [\#4671](https://github.com/apache/arrow-rs/pull/4671) ([tustvold](https://github.com/tustvold)) +- Use Config System for Object Store Integration Tests [\#4628](https://github.com/apache/arrow-rs/pull/4628) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Prepare arrow 45 [\#4590](https://github.com/apache/arrow-rs/pull/4590) ([tustvold](https://github.com/tustvold)) +- Add Support for Microsoft Fabric / OneLake [\#4573](https://github.com/apache/arrow-rs/pull/4573) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([vmuddassir-msft](https://github.com/vmuddassir-msft)) +- Cleanup multipart upload trait [\#4572](https://github.com/apache/arrow-rs/pull/4572) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make object\_store::multipart public [\#4570](https://github.com/apache/arrow-rs/pull/4570) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([yjshen](https://github.com/yjshen)) +- Handle empty S3 payloads \(\#4514\) [\#4518](https://github.com/apache/arrow-rs/pull/4518) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: Export `ClientConfigKey` and add `HttpBuilder::with_config` [\#4516](https://github.com/apache/arrow-rs/pull/4516) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([thehabbos007](https://github.com/thehabbos007)) +- object\_store: Implement `ObjectStore` for `Arc` [\#4502](https://github.com/apache/arrow-rs/pull/4502) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- object\_store/InMemory: Add `fork()` fn and deprecate `clone()` fn [\#4499](https://github.com/apache/arrow-rs/pull/4499) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) +- gcp: Exclude authorization header when bearer empty [\#4418](https://github.com/apache/arrow-rs/pull/4418) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([vrongmeal](https://github.com/vrongmeal)) +- Support copy\_if\_not\_exists for Cloudflare R2 \(\#4190\) [\#4239](https://github.com/apache/arrow-rs/pull/4239) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +## [object_store_0.6.0](https://github.com/apache/arrow-rs/tree/object_store_0.6.0) (2023-05-18) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.6...object_store_0.6.0) + +**Breaking changes:** + +- Add ObjectStore::get\_opts \(\#2241\) [\#4212](https://github.com/apache/arrow-rs/pull/4212) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Simplify ObjectStore configuration pattern [\#4189](https://github.com/apache/arrow-rs/pull/4189) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: fix: Incorrect parsing of https Path Style S3 url [\#4082](https://github.com/apache/arrow-rs/pull/4082) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- feat: add etag for objectMeta [\#3937](https://github.com/apache/arrow-rs/pull/3937) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Weijun-H](https://github.com/Weijun-H)) + +**Implemented enhancements:** + +- Object Store Authorization [\#4223](https://github.com/apache/arrow-rs/issues/4223) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Use XML API for GCS [\#4209](https://github.com/apache/arrow-rs/issues/4209) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- ObjectStore with\_url Should Handle Path [\#4199](https://github.com/apache/arrow-rs/issues/4199) +- Return Error on Invalid Config Value [\#4191](https://github.com/apache/arrow-rs/issues/4191) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Extensible ObjectStore Authentication [\#4163](https://github.com/apache/arrow-rs/issues/4163) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: When using an AWS profile, obtain the default AWS region from the active profile [\#4158](https://github.com/apache/arrow-rs/issues/4158) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- InMemory append API [\#4152](https://github.com/apache/arrow-rs/issues/4152) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support accessing ipc Reader/Writer inner by reference [\#4121](https://github.com/apache/arrow-rs/issues/4121) +- \[object\_store\] Retry requests on connection error [\#4119](https://github.com/apache/arrow-rs/issues/4119) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Instantiate object store from provided url with store options [\#4047](https://github.com/apache/arrow-rs/issues/4047) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Builders \(S3/Azure/GCS\) are missing the `get method` to get the actual configuration information [\#4021](https://github.com/apache/arrow-rs/issues/4021) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- ObjectStore::head Returns Directory for LocalFileSystem and Hierarchical Azure [\#4230](https://github.com/apache/arrow-rs/issues/4230) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: different behavior from aws cli for default profile [\#4137](https://github.com/apache/arrow-rs/issues/4137) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- ImdsManagedIdentityOAuthProvider should send resource ID instead of OIDC scope [\#4096](https://github.com/apache/arrow-rs/issues/4096) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Update readme to remove reference to Jira [\#4091](https://github.com/apache/arrow-rs/issues/4091) +- object\_store: Incorrect parsing of https Path Style S3 url [\#4078](https://github.com/apache/arrow-rs/issues/4078) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- \[object\_store\] `local::tests::test_list_root` test fails during release verification 
[\#3772](https://github.com/apache/arrow-rs/issues/3772) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Remove AWS\_PROFILE support [\#4238](https://github.com/apache/arrow-rs/pull/4238) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Expose AwsAuthorizer [\#4237](https://github.com/apache/arrow-rs/pull/4237) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Expose CredentialProvider [\#4235](https://github.com/apache/arrow-rs/pull/4235) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Return NotFound for directories in Head and Get \(\#4230\) [\#4231](https://github.com/apache/arrow-rs/pull/4231) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Standardise credentials API \(\#4223\) \(\#4163\) [\#4225](https://github.com/apache/arrow-rs/pull/4225) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Extract Common Listing and Retrieval Functionality [\#4220](https://github.com/apache/arrow-rs/pull/4220) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- feat\(object-store\): extend Options API for http client [\#4208](https://github.com/apache/arrow-rs/pull/4208) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- Consistently use GCP XML API [\#4207](https://github.com/apache/arrow-rs/pull/4207) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Implement list\_with\_offset for PrefixStore [\#4203](https://github.com/apache/arrow-rs/pull/4203) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Allow setting ClientOptions with Options API [\#4202](https://github.com/apache/arrow-rs/pull/4202) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Create ObjectStore from URL and Options \(\#4047\) [\#4200](https://github.com/apache/arrow-rs/pull/4200) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Skip test\_list\_root on OS X \(\#3772\) [\#4198](https://github.com/apache/arrow-rs/pull/4198) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Recognise R2 URLs for S3 object store \(\#4190\) [\#4194](https://github.com/apache/arrow-rs/pull/4194) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix ImdsManagedIdentityProvider \(\#4096\) [\#4193](https://github.com/apache/arrow-rs/pull/4193) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Deffered Object Store Config Parsing \(\#4191\) [\#4192](https://github.com/apache/arrow-rs/pull/4192) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Object Store \(AWS\): Support dynamically resolving S3 bucket region [\#4188](https://github.com/apache/arrow-rs/pull/4188) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mr-brobot](https://github.com/mr-brobot)) +- Faster prefix match in object\_store path handling [\#4164](https://github.com/apache/arrow-rs/pull/4164) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Object Store \(AWS\): Support region configured via named profile [\#4161](https://github.com/apache/arrow-rs/pull/4161) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mr-brobot](https://github.com/mr-brobot)) +- InMemory append API [\#4153](https://github.com/apache/arrow-rs/pull/4153) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([berkaysynnada](https://github.com/berkaysynnada)) +- docs: fix the wrong ln command in CONTRIBUTING.md [\#4139](https://github.com/apache/arrow-rs/pull/4139) ([SteveLauC](https://github.com/SteveLauC)) +- Display the file path in the error message when failed to open credentials file for GCS [\#4124](https://github.com/apache/arrow-rs/pull/4124) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([haoxins](https://github.com/haoxins)) +- Retry on Connection Errors [\#4120](https://github.com/apache/arrow-rs/pull/4120) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kindly](https://github.com/kindly)) +- Simplify reference to GitHub issues [\#4092](https://github.com/apache/arrow-rs/pull/4092) ([bkmgit](https://github.com/bkmgit)) +- Use reqwest build\_split [\#4039](https://github.com/apache/arrow-rs/pull/4039) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix object\_store CI [\#4037](https://github.com/apache/arrow-rs/pull/4037) ([tustvold](https://github.com/tustvold)) +- Add get\_config\_value to AWS/Azure/GCP Builders [\#4035](https://github.com/apache/arrow-rs/pull/4035) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([r4ntix](https://github.com/r4ntix)) +- Update AWS SDK [\#3993](https://github.com/apache/arrow-rs/pull/3993) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +## [object_store_0.5.6](https://github.com/apache/arrow-rs/tree/object_store_0.5.6) (2023-03-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.5...object_store_0.5.6) + +**Implemented enhancements:** + +- Document ObjectStore::list Ordering [\#3975](https://github.com/apache/arrow-rs/issues/3975) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add option to start listing at a particular key [\#3970](https://github.com/apache/arrow-rs/issues/3970) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Implement `ObjectStore` for trait objects [\#3865](https://github.com/apache/arrow-rs/issues/3865) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add ObjectStore::append [\#3790](https://github.com/apache/arrow-rs/issues/3790) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Make `InMemory` object store track last modified time for each entry [\#3782](https://github.com/apache/arrow-rs/issues/3782) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support Unsigned S3 Payloads [\#3737](https://github.com/apache/arrow-rs/issues/3737) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add Content-MD5 
or checksum header for using an Object Locked S3 [\#3725](https://github.com/apache/arrow-rs/issues/3725) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- LocalFileSystem::put is not Atomic [\#3780](https://github.com/apache/arrow-rs/issues/3780) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Add ObjectStore::list\_with\_offset \(\#3970\) [\#3973](https://github.com/apache/arrow-rs/pull/3973) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Remove incorrect validation logic on S3 bucket names [\#3947](https://github.com/apache/arrow-rs/pull/3947) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([rtyler](https://github.com/rtyler)) +- Prepare arrow 36 [\#3935](https://github.com/apache/arrow-rs/pull/3935) ([tustvold](https://github.com/tustvold)) +- fix: Specify content length for gcp copy request [\#3921](https://github.com/apache/arrow-rs/pull/3921) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([scsmithr](https://github.com/scsmithr)) +- Revert structured ArrayData \(\#3877\) [\#3894](https://github.com/apache/arrow-rs/pull/3894) ([tustvold](https://github.com/tustvold)) +- Add support for checksum algorithms in AWS [\#3873](https://github.com/apache/arrow-rs/pull/3873) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([trueleo](https://github.com/trueleo)) +- Rename PrefixObjectStore to PrefixStore [\#3870](https://github.com/apache/arrow-rs/pull/3870) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Implement append for LimitStore, PrefixObjectStore, ThrottledStore [\#3869](https://github.com/apache/arrow-rs/pull/3869) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Supporting metadata fetch without open file read mode [\#3868](https://github.com/apache/arrow-rs/pull/3868) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([metesynnada](https://github.com/metesynnada)) +- Impl ObjectStore for trait object [\#3866](https://github.com/apache/arrow-rs/pull/3866) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Kinrany](https://github.com/Kinrany)) +- Update quick-xml requirement from 0.27.0 to 0.28.0 [\#3857](https://github.com/apache/arrow-rs/pull/3857) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update changelog for 35.0.0 [\#3843](https://github.com/apache/arrow-rs/pull/3843) ([tustvold](https://github.com/tustvold)) +- Cleanup ApplicationDefaultCredentials [\#3799](https://github.com/apache/arrow-rs/pull/3799) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make InMemory object store track last modified time for each entry [\#3796](https://github.com/apache/arrow-rs/pull/3796) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Weijun-H](https://github.com/Weijun-H)) +- Add ObjectStore::append [\#3791](https://github.com/apache/arrow-rs/pull/3791) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make LocalFileSystem::put atomic \(\#3780\) [\#3781](https://github.com/apache/arrow-rs/pull/3781) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add support for unsigned payloads in aws [\#3741](https://github.com/apache/arrow-rs/pull/3741) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([trueleo](https://github.com/trueleo)) + +## [object_store_0.5.5](https://github.com/apache/arrow-rs/tree/object_store_0.5.5) (2023-02-27) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.4...object_store_0.5.5) + +**Implemented enhancements:** + +- object\_store: support azure cli credential [\#3697](https://github.com/apache/arrow-rs/issues/3697) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: support encoded path as input [\#3651](https://github.com/apache/arrow-rs/issues/3651) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- object-store: aws\_profile fails to load static credentials [\#3765](https://github.com/apache/arrow-rs/issues/3765) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Inconsistent Behaviour Listing File [\#3712](https://github.com/apache/arrow-rs/issues/3712) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: bearer token is azure is used like access key [\#3696](https://github.com/apache/arrow-rs/issues/3696) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- object-store: fix handling of AWS profile credentials without expiry [\#3766](https://github.com/apache/arrow-rs/pull/3766) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([helmus](https://github.com/helmus)) +- update object\_store deps to patch potential security vulnerabilities [\#3761](https://github.com/apache/arrow-rs/pull/3761) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([spencerbart](https://github.com/spencerbart)) +- Filter exact list prefix matches for azure gen2 accounts [\#3714](https://github.com/apache/arrow-rs/pull/3714) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- Filter exact list prefix matches for MemoryStore and HttpStore \(\#3712\) [\#3713](https://github.com/apache/arrow-rs/pull/3713) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: azure cli authorization [\#3698](https://github.com/apache/arrow-rs/pull/3698) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- object\_store: add Path::from\_url\_path [\#3663](https://github.com/apache/arrow-rs/pull/3663) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jychen7](https://github.com/jychen7)) + +## [object_store_0.5.4](https://github.com/apache/arrow-rs/tree/object_store_0.5.4) (2023-01-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.3...object_store_0.5.4) + +**Implemented enhancements:** + +- \[object\_store\] support more identity based auth flows for azure [\#3580](https://github.com/apache/arrow-rs/issues/3580) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Implement workload identity and application default credentials for GCP object store. 
[\#3533](https://github.com/apache/arrow-rs/issues/3533) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support GCP Workload Identity [\#3490](https://github.com/apache/arrow-rs/issues/3490) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Allow providing service account key directly when building GCP object store client [\#3488](https://github.com/apache/arrow-rs/issues/3488) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Closed issues:** + +- object\_store: temporary aws credentials not refreshed? [\#3446](https://github.com/apache/arrow-rs/issues/3446) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Final tweaks to 32.0.0 changelog [\#3618](https://github.com/apache/arrow-rs/pull/3618) ([tustvold](https://github.com/tustvold)) +- Update AWS SDK [\#3617](https://github.com/apache/arrow-rs/pull/3617) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ClientOption.allow\_insecure [\#3600](https://github.com/apache/arrow-rs/pull/3600) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([poelzi](https://github.com/poelzi)) +- \[object\_store\] support azure managed and workload identities [\#3581](https://github.com/apache/arrow-rs/pull/3581) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- Additional GCP authentication [\#3541](https://github.com/apache/arrow-rs/pull/3541) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([winding-lines](https://github.com/winding-lines)) +- Update aws-config and aws-types requirements from 0.52 to 0.53 [\#3539](https://github.com/apache/arrow-rs/pull/3539) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) +- Use GHA concurrency groups \(\#3495\) [\#3538](https://github.com/apache/arrow-rs/pull/3538) ([tustvold](https://github.com/tustvold)) +- Remove azurite test exception [\#3497](https://github.com/apache/arrow-rs/pull/3497) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- feat: Allow providing a service account key directly for GCS [\#3489](https://github.com/apache/arrow-rs/pull/3489) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([scsmithr](https://github.com/scsmithr)) + +## [object_store_0.5.3](https://github.com/apache/arrow-rs/tree/object_store_0.5.3) (2023-01-04) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.2...object_store_0.5.3) + +**Implemented enhancements:** + +- Derive Clone for the builders in object-store. 
[\#3419](https://github.com/apache/arrow-rs/issues/3419) +- Add a constant prefix object store wrapper [\#3328](https://github.com/apache/arrow-rs/issues/3328) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add support for content-type while uploading files through ObjectStore API [\#3300](https://github.com/apache/arrow-rs/issues/3300) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add HttpStore [\#3294](https://github.com/apache/arrow-rs/issues/3294) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add support for Azure Data Lake Storage Gen2 \(aka: ADLS Gen2\) in Object Store library [\#3283](https://github.com/apache/arrow-rs/issues/3283) +- object\_store: Add Put and Multipart Upload Doc Examples [\#2863](https://github.com/apache/arrow-rs/issues/2863) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Closed issues:** + +- Only flush buffered multi-part data on poll\_shutdown not on poll\_flush [\#3390](https://github.com/apache/arrow-rs/issues/3390) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- object\_store: builder configuration api [\#3436](https://github.com/apache/arrow-rs/pull/3436) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- Derive Clone for ObjectStore builders and Make URL Parsing Stricter \(\#3419\) [\#3424](https://github.com/apache/arrow-rs/pull/3424) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add Put and Multipart Put doc examples [\#3420](https://github.com/apache/arrow-rs/pull/3420) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([GeauxEric](https://github.com/GeauxEric)) +- object\_store: update localstack instructions [\#3403](https://github.com/apache/arrow-rs/pull/3403) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- object\_store: Flush buffered multipart only during poll\_shutdown [\#3397](https://github.com/apache/arrow-rs/pull/3397) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([askoa](https://github.com/askoa)) +- Update quick-xml to 0.27 [\#3395](https://github.com/apache/arrow-rs/pull/3395) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add HttpStore \(\#3294\) [\#3380](https://github.com/apache/arrow-rs/pull/3380) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- add support for content-type in `ClientOptions` [\#3358](https://github.com/apache/arrow-rs/pull/3358) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([ByteBaker](https://github.com/ByteBaker)) +- Update AWS SDK [\#3349](https://github.com/apache/arrow-rs/pull/3349) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Upstream newline\_delimited\_stream and ChunkedStore from DataFusion [\#3341](https://github.com/apache/arrow-rs/pull/3341) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- feat\(object\_store\): add PrefixObjectStore [\#3329](https://github.com/apache/arrow-rs/pull/3329) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([roeap](https://github.com/roeap)) +- feat\(object\_store\): parse well-known storage urls [\#3327](https://github.com/apache/arrow-rs/pull/3327) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) +- Disable getrandom object\_store [\#3278](https://github.com/apache/arrow-rs/pull/3278) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Reload token from AWS\_WEB\_IDENTITY\_TOKEN\_FILE [\#3274](https://github.com/apache/arrow-rs/pull/3274) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Minor: skip aws integration test if TEST\_INTEGRATION is not set [\#3262](https://github.com/apache/arrow-rs/pull/3262) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) + +## [object_store_0.5.2](https://github.com/apache/arrow-rs/tree/object_store_0.5.2) (2022-12-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.1...object_store_0.5.2) + +**Implemented enhancements:** + +- Object Store: Allow custom reqwest client [\#3127](https://github.com/apache/arrow-rs/issues/3127) +- socks5 proxy support for the object\_store crate [\#2989](https://github.com/apache/arrow-rs/issues/2989) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Cannot query S3 paths containing whitespace [\#2799](https://github.com/apache/arrow-rs/issues/2799) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- object\_store\(gcp\): GCP complains about content-length for copy [\#3235](https://github.com/apache/arrow-rs/issues/3235) +- object\_store\(aws\): EntityTooSmall error on multi-part upload [\#3233](https://github.com/apache/arrow-rs/issues/3233) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Add more ClientConfig Options for Object Store RequestBuilder \(\#3127\) [\#3256](https://github.com/apache/arrow-rs/pull/3256) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ObjectStore ClientConfig [\#3252](https://github.com/apache/arrow-rs/pull/3252) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- fix\(object\_store,gcp\): test copy\_if\_not\_exist [\#3236](https://github.com/apache/arrow-rs/pull/3236) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- fix\(object\_store,aws,gcp\): multipart upload enforce size limit of 5 MiB not 5MB [\#3234](https://github.com/apache/arrow-rs/pull/3234) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- object\_store: add support for using proxy\_url for connection testing [\#3109](https://github.com/apache/arrow-rs/pull/3109) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([sum12](https://github.com/sum12)) +- Update AWS SDK [\#2974](https://github.com/apache/arrow-rs/pull/2974) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update quick-xml requirement from 0.25.0 to 0.26.0 [\#2918](https://github.com/apache/arrow-rs/pull/2918) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([dependabot[bot]](https://github.com/apps/dependabot)) +- Support building object_store and parquet on wasm32-unknown-unknown target [\#2896](https://github.com/apache/arrow-rs/pull/2899) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jondo2010](https://github.com/jondo2010)) +- Add experimental AWS\_PROFILE support \(\#2178\) [\#2891](https://github.com/apache/arrow-rs/pull/2891) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +## [object_store_0.5.1](https://github.com/apache/arrow-rs/tree/object_store_0.5.1) (2022-10-04) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.0...object_store_0.5.1) + +**Implemented enhancements:** + +- Allow HTTP S3 URLs [\#2806](https://github.com/apache/arrow-rs/issues/2806) +- object\_store: support AWS ECS instance credentials [\#2802](https://github.com/apache/arrow-rs/issues/2802) +- Object Store S3 Alibaba Cloud OSS support [\#2777](https://github.com/apache/arrow-rs/issues/2777) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Expose option to use GCS object store in integration tests [\#2627](https://github.com/apache/arrow-rs/issues/2627) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- S3 Signature Error Performing List With Prefix Containing Spaces [\#2800](https://github.com/apache/arrow-rs/issues/2800) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Erratic Behaviour if Incorrect S3 Region Configured [\#2795](https://github.com/apache/arrow-rs/issues/2795) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Support for overriding instance metadata endpoint [\#2811](https://github.com/apache/arrow-rs/pull/2811) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- Allow Configuring non-TLS HTTP Connections in AmazonS3Builder::from\_env [\#2807](https://github.com/apache/arrow-rs/pull/2807) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Fix S3 query canonicalization \(\#2800\) [\#2801](https://github.com/apache/arrow-rs/pull/2801) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Handle incomplete HTTP redirects missing LOCATION \(\#2795\) [\#2796](https://github.com/apache/arrow-rs/pull/2796) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Handle S3 virtual host request type [\#2782](https://github.com/apache/arrow-rs/pull/2782) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([askoa](https://github.com/askoa)) +- Fix object\_store multipart uploads on S3 Compatible Stores [\#2731](https://github.com/apache/arrow-rs/pull/2731) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([mildbyte](https://github.com/mildbyte)) + + +## [object_store_0.5.0](https://github.com/apache/arrow-rs/tree/object_store_0.5.0) (2022-09-08) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.4.0...object_store_0.5.0) + +**Breaking changes:** + +- Replace azure sdk with custom implementation [\#2509](https://github.com/apache/arrow-rs/pull/2509) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] 
([roeap](https://github.com/roeap)) +- Replace rusoto with custom implementation for AWS \(\#2176\) [\#2352](https://github.com/apache/arrow-rs/pull/2352) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- IMDSv1 Fallback for S3 [\#2609](https://github.com/apache/arrow-rs/issues/2609) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Print Response Body On Error [\#2572](https://github.com/apache/arrow-rs/issues/2572) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Coalesce Ranges Parallel Fetch [\#2562](https://github.com/apache/arrow-rs/issues/2562) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support Coalescing Out-of-Order Ranges [\#2561](https://github.com/apache/arrow-rs/issues/2561) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Add TokenProvider authorization to azure [\#2373](https://github.com/apache/arrow-rs/issues/2373) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- AmazonS3Builder::from\_env to populate credentials from environment [\#2361](https://github.com/apache/arrow-rs/issues/2361) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- AmazonS3 Support IMDSv2 [\#2350](https://github.com/apache/arrow-rs/issues/2350) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- Retry Logic Fails to Retry Server Errors [\#2573](https://github.com/apache/arrow-rs/issues/2573) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Fix multiple part uploads at once making vector size inconsistent [\#2681](https://github.com/apache/arrow-rs/pull/2681) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([gruuya](https://github.com/gruuya)) +- Fix panic in `object_store::util::coalesce_ranges` [\#2554](https://github.com/apache/arrow-rs/pull/2554) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([thinkharderdev](https://github.com/thinkharderdev)) + +**Merged pull requests:** + +- update doc for object\_store copy\_if\_not\_exists [\#2653](https://github.com/apache/arrow-rs/pull/2653) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([JanKaul](https://github.com/JanKaul)) +- Update quick-xml 0.24 [\#2625](https://github.com/apache/arrow-rs/pull/2625) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add IMDSv1 fallback \(\#2609\) [\#2610](https://github.com/apache/arrow-rs/pull/2610) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- ObjectStore cleanup \(\#2587\) [\#2590](https://github.com/apache/arrow-rs/pull/2590) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix retry logic \(\#2573\) \(\#2572\) [\#2574](https://github.com/apache/arrow-rs/pull/2574) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Improve coalesce\_ranges \(\#2561\) \(\#2562\) [\#2563](https://github.com/apache/arrow-rs/pull/2563) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update environment variable name for amazonS3builder in integration \(\#2550\) 
[\#2553](https://github.com/apache/arrow-rs/pull/2553) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([amrltqt](https://github.com/amrltqt)) +- Build AmazonS3builder from environment variables \(\#2361\) [\#2536](https://github.com/apache/arrow-rs/pull/2536) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([amrltqt](https://github.com/amrltqt)) +- feat: add token provider authorization to azure store [\#2374](https://github.com/apache/arrow-rs/pull/2374) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([roeap](https://github.com/roeap)) + +## [object_store_0.4.0](https://github.com/apache/arrow-rs/tree/object_store_0.4.0) (2022-08-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.3.0...object_store_0.4.0) + +**Implemented enhancements:** + +- Relax Path Validation to Allow Any Percent-Encoded Sequence [\#2355](https://github.com/apache/arrow-rs/issues/2355) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support get\_multi\_ranges in ObjectStore [\#2293](https://github.com/apache/arrow-rs/issues/2293) +- object\_store: Create explicit test for symlinks [\#2206](https://github.com/apache/arrow-rs/issues/2206) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Make builder style configuration for object stores [\#2203](https://github.com/apache/arrow-rs/issues/2203) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Add example in the main documentation readme [\#2202](https://github.com/apache/arrow-rs/issues/2202) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- Azure/S3 Storage Fails to Copy Blob with URL-encoded Path [\#2353](https://github.com/apache/arrow-rs/issues/2353) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Accessing a file with a percent-encoded name on the filesystem with ObjectStore LocalFileSystem [\#2349](https://github.com/apache/arrow-rs/issues/2349) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Documentation updates:** + +- Improve `object_store crate` documentation [\#2260](https://github.com/apache/arrow-rs/pull/2260) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Canonicalize filesystem paths in user-facing APIs \(\#2370\) [\#2371](https://github.com/apache/arrow-rs/pull/2371) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix object\_store lint [\#2367](https://github.com/apache/arrow-rs/pull/2367) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Relax path validation \(\#2355\) [\#2356](https://github.com/apache/arrow-rs/pull/2356) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix Copy from percent-encoded path \(\#2353\) [\#2354](https://github.com/apache/arrow-rs/pull/2354) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ObjectStore::get\_ranges \(\#2293\) [\#2336](https://github.com/apache/arrow-rs/pull/2336) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Remove vestigal ` 
object_store/.circleci/` [\#2337](https://github.com/apache/arrow-rs/pull/2337) ([alamb](https://github.com/alamb)) +- Handle symlinks in LocalFileSystem \(\#2206\) [\#2269](https://github.com/apache/arrow-rs/pull/2269) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Retry GCP requests on server error [\#2243](https://github.com/apache/arrow-rs/pull/2243) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add LimitStore \(\#2175\) [\#2242](https://github.com/apache/arrow-rs/pull/2242) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Only trigger `arrow` CI on changes to arrow [\#2227](https://github.com/apache/arrow-rs/pull/2227) ([alamb](https://github.com/alamb)) +- Update instructions on how to join the Slack channel [\#2219](https://github.com/apache/arrow-rs/pull/2219) ([HaoYang670](https://github.com/HaoYang670)) +- Add Builder style config objects for object\_store [\#2204](https://github.com/apache/arrow-rs/pull/2204) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Ignore broken symlinks for LocalFileSystem object store [\#2195](https://github.com/apache/arrow-rs/pull/2195) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jccampagne](https://github.com/jccampagne)) +- Change CI names to match crate names [\#2189](https://github.com/apache/arrow-rs/pull/2189) ([alamb](https://github.com/alamb)) +- Split most arrow specific CI checks into their own workflows \(reduce common CI time to 21 minutes\) [\#2168](https://github.com/apache/arrow-rs/pull/2168) ([alamb](https://github.com/alamb)) +- Remove another attempt to cache target directory in action.yaml [\#2167](https://github.com/apache/arrow-rs/pull/2167) ([alamb](https://github.com/alamb)) +- Run actions on push to master, pull requests [\#2166](https://github.com/apache/arrow-rs/pull/2166) ([alamb](https://github.com/alamb)) +- Break parquet\_derive and arrow\_flight tests into their own workflows [\#2165](https://github.com/apache/arrow-rs/pull/2165) ([alamb](https://github.com/alamb)) +- Only run integration tests when `arrow` changes [\#2152](https://github.com/apache/arrow-rs/pull/2152) ([alamb](https://github.com/alamb)) +- Break out docs CI job to its own github action [\#2151](https://github.com/apache/arrow-rs/pull/2151) ([alamb](https://github.com/alamb)) +- Do not pretend to cache rust build artifacts, speed up CI by ~20% [\#2150](https://github.com/apache/arrow-rs/pull/2150) ([alamb](https://github.com/alamb)) +- Port `object_store` integration tests, use github actions [\#2148](https://github.com/apache/arrow-rs/pull/2148) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Port Add stream upload \(multi-part upload\) [\#2147](https://github.com/apache/arrow-rs/pull/2147) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Increase upper wait time to reduce flakiness of object store test [\#2142](https://github.com/apache/arrow-rs/pull/2142) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) + +\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff 
--git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..6dd9d0f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,72 @@ + + +# Changelog + +## [object_store_0.12.0](https://github.com/apache/arrow-rs/tree/object_store_0.12.0) (2025-03-05) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.11.2...object_store_0.12.0) + +**Breaking changes:** + +- feat: add `Extensions` to object store `PutMultipartOpts` [\#7214](https://github.com/apache/arrow-rs/pull/7214) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- feat: add `Extensions` to object store `PutOptions` [\#7213](https://github.com/apache/arrow-rs/pull/7213) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- chore: enable conditional put by default for S3 [\#7181](https://github.com/apache/arrow-rs/pull/7181) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([meteorgan](https://github.com/meteorgan)) +- feat: add `Extensions` to object store `GetOptions` [\#7170](https://github.com/apache/arrow-rs/pull/7170) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- feat\(object\_store\): Override DNS Resolution to Randomize IP Selection [\#7123](https://github.com/apache/arrow-rs/pull/7123) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- Use `u64` range instead of `usize`, for better wasm32 support [\#6961](https://github.com/apache/arrow-rs/pull/6961) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([XiangpengHao](https://github.com/XiangpengHao)) +- object\_store: Add enabled-by-default "fs" feature [\#6636](https://github.com/apache/arrow-rs/pull/6636) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- Return `BoxStream` with `'static` lifetime from `ObjectStore::list` [\#6619](https://github.com/apache/arrow-rs/pull/6619) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylebarron](https://github.com/kylebarron)) +- object\_store: Migrate from snafu to thiserror [\#6266](https://github.com/apache/arrow-rs/pull/6266) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) + +**Implemented enhancements:** + +- Object Store: S3 IP address selection is biased [\#7117](https://github.com/apache/arrow-rs/issues/7117) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: GCSObjectStore should derive Clone [\#7113](https://github.com/apache/arrow-rs/issues/7113) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Remove all RCs after release [\#7059](https://github.com/apache/arrow-rs/issues/7059) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- LocalFileSystem::list\_with\_offset is very slow over network file system [\#7018](https://github.com/apache/arrow-rs/issues/7018) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Release object store `0.11.2` \(non API breaking\) Around Dec 15 2024 [\#6902](https://github.com/apache/arrow-rs/issues/6902) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- LocalFileSystem errors with satisfiable range request 
[\#6749](https://github.com/apache/arrow-rs/issues/6749) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- ObjectStore WASM32 Support [\#7226](https://github.com/apache/arrow-rs/pull/7226) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- \[main\] Bump arrow version to 54.2.1 \(\#7207\) [\#7212](https://github.com/apache/arrow-rs/pull/7212) ([alamb](https://github.com/alamb)) +- Decouple ObjectStore from Reqwest [\#7183](https://github.com/apache/arrow-rs/pull/7183) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: Disable all compression formats in HTTP reqwest client [\#7143](https://github.com/apache/arrow-rs/pull/7143) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylewlacy](https://github.com/kylewlacy)) +- refactor: remove unused `async` from `InMemory::entry` [\#7133](https://github.com/apache/arrow-rs/pull/7133) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([crepererum](https://github.com/crepererum)) +- object\_store/gcp: derive Clone for GoogleCloudStorage [\#7112](https://github.com/apache/arrow-rs/pull/7112) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([james-rms](https://github.com/james-rms)) +- Update version to 54.2.0 and add CHANGELOG [\#7110](https://github.com/apache/arrow-rs/pull/7110) ([alamb](https://github.com/alamb)) +- Remove all RCs after release [\#7060](https://github.com/apache/arrow-rs/pull/7060) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kou](https://github.com/kou)) +- Update release schedule README.md [\#7053](https://github.com/apache/arrow-rs/pull/7053) ([alamb](https://github.com/alamb)) +- Create GitHub releases automatically on tagging [\#7042](https://github.com/apache/arrow-rs/pull/7042) ([kou](https://github.com/kou)) +- Change Log On Succesful S3 Copy / Multipart Upload to Debug [\#7033](https://github.com/apache/arrow-rs/pull/7033) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([diptanu](https://github.com/diptanu)) +- Prepare for `54.1.0` release [\#7031](https://github.com/apache/arrow-rs/pull/7031) ([alamb](https://github.com/alamb)) +- Add a custom implementation `LocalFileSystem::list_with_offset` [\#7019](https://github.com/apache/arrow-rs/pull/7019) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([corwinjoy](https://github.com/corwinjoy)) +- Improve docs for `AmazonS3Builder::from_env` [\#6977](https://github.com/apache/arrow-rs/pull/6977) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylebarron](https://github.com/kylebarron)) +- Fix WASM CI for Rust 1.84 release [\#6963](https://github.com/apache/arrow-rs/pull/6963) ([alamb](https://github.com/alamb)) +- Update itertools requirement from 0.13.0 to 0.14.0 in /object\_store [\#6925](https://github.com/apache/arrow-rs/pull/6925) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix LocalFileSystem with range request that ends beyond end of file [\#6751](https://github.com/apache/arrow-rs/pull/6751) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([kylebarron](https://github.com/kylebarron)) + + + +\* *This Changelog was automatically generated by 
[github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..5444ec7 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,202 @@ + + +# Development instructions + +## Running Tests + +Tests can be run using `cargo` + +```shell +cargo test +``` + +## Running Integration Tests + +By default, integration tests are not run. To run them you will need to set `TEST_INTEGRATION=1` and then provide the +necessary configuration for that object store + +### AWS + +To test the S3 integration against [localstack](https://localstack.cloud/) + +First start up a container running localstack + +``` +$ LOCALSTACK_VERSION=sha256:a0b79cb2430f1818de2c66ce89d41bba40f5a1823410f5a7eaf3494b692eed97 +$ podman run -d -p 4566:4566 localstack/localstack@$LOCALSTACK_VERSION +$ podman run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2 +``` + +Setup environment + +``` +export TEST_INTEGRATION=1 +export AWS_DEFAULT_REGION=us-east-1 +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test +export AWS_ENDPOINT=http://localhost:4566 +export AWS_ALLOW_HTTP=true +export AWS_BUCKET_NAME=test-bucket +``` + +Create a bucket using the AWS CLI + +``` +podman run --net=host --env-host amazon/aws-cli --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket +``` + +Or directly with: + +``` +aws s3 mb s3://test-bucket --endpoint-url=http://localhost:4566 +aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5 +``` + +Run tests + +``` +$ cargo test --features aws +``` + +#### Encryption tests + +To create an encryption key for the tests, you can run the following command: + +``` +export AWS_SSE_KMS_KEY_ID=$(aws --endpoint-url=http://localhost:4566 \ + kms create-key --description "test key" | + jq -r '.KeyMetadata.KeyId') +``` + +To run integration tests with encryption, you can set the following environment variables: + +``` +export AWS_SERVER_SIDE_ENCRYPTION=aws:kms +export AWS_SSE_BUCKET_KEY=false +cargo test --features aws +``` + +As well as: + +``` +unset AWS_SSE_BUCKET_KEY +export AWS_SERVER_SIDE_ENCRYPTION=aws:kms:dsse +cargo test --features aws +``` + +#### SSE-C Encryption tests + +Unfortunately, localstack does not support SSE-C encryption (https://github.com/localstack/localstack/issues/11356). + +We will use [MinIO](https://min.io/docs/minio/container/operations/server-side-encryption.html) to test SSE-C encryption. + +First, create a self-signed certificate to enable HTTPS for MinIO, as SSE-C requires HTTPS. + +```shell +mkdir ~/certs +cd ~/certs +openssl genpkey -algorithm RSA -out private.key +openssl req -new -key private.key -out request.csr -subj "/C=US/ST=State/L=City/O=Organization/OU=Unit/CN=example.com/emailAddress=email@example.com" +openssl x509 -req -days 365 -in request.csr -signkey private.key -out public.crt +rm request.csr +``` + +Second, start MinIO with the self-signed certificate. + +```shell +docker run -d \ + -p 9000:9000 \ + --name minio \ + -v ${HOME}/certs:/root/.minio/certs \ + -e "MINIO_ROOT_USER=minio" \ + -e "MINIO_ROOT_PASSWORD=minio123" \ + minio/minio server /data +``` + +Create a test bucket. 
+ +```shell +export AWS_BUCKET_NAME=test-bucket +export AWS_ACCESS_KEY_ID=minio +export AWS_SECRET_ACCESS_KEY=minio123 +export AWS_ENDPOINT=https://localhost:9000 +aws s3 mb s3://test-bucket --endpoint-url=https://localhost:9000 --no-verify-ssl +``` + +Run the tests. The real test is `test_s3_ssec_encryption_with_minio()` + +```shell +export TEST_S3_SSEC_ENCRYPTION=1 +cargo test --features aws --package object_store --lib aws::tests::test_s3_ssec_encryption_with_minio -- --exact --nocapture +``` + + + +### Azure + +To test the Azure integration +against [azurite](https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio) + +Startup azurite + +``` +$ podman run -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azure-storage/azurite +``` + +Create a bucket + +``` +$ podman run --net=host mcr.microsoft.com/azure-cli az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;' +``` + +Run tests + +```shell +AZURE_USE_EMULATOR=1 \ +TEST_INTEGRATION=1 \ +OBJECT_STORE_BUCKET=test-bucket \ +AZURE_STORAGE_ACCOUNT=devstoreaccount1 \ +AZURE_STORAGE_ACCESS_KEY=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw== \ +cargo test --features azure +``` + +### GCP + +To test the GCS integration, we use [Fake GCS Server](https://github.com/fsouza/fake-gcs-server) + +Startup the fake server: + +```shell +docker run -p 4443:4443 tustvold/fake-gcs-server -scheme http +``` + +Configure the account: +```shell +curl -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" +echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > /tmp/gcs.json +``` + +Now run the tests: +```shell +TEST_INTEGRATION=1 \ +OBJECT_STORE_BUCKET=test-bucket \ +GOOGLE_SERVICE_ACCOUNT=/tmp/gcs.json \ +cargo test -p object_store --features=gcp +``` diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8370cd5 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "object_store" +version = "0.12.0" +edition = "2021" +license = "MIT/Apache-2.0" +readme = "README.md" +description = "A generic object store interface for uniformly interacting with AWS S3, Google Cloud Storage, Azure Blob Storage and local files." 
+keywords = ["object", "storage", "cloud"] +repository = "https://github.com/apache/arrow-rs/tree/main/object_store" +rust-version = "1.64.0" + +[package.metadata.docs.rs] +all-features = true + +[dependencies] # In alphabetical order +async-trait = "0.1.53" +bytes = "1.0" +chrono = { version = "0.4.34", default-features = false, features = ["clock"] } +futures = "0.3" +http = "1.2.0" +humantime = "2.1" +itertools = "0.14.0" +parking_lot = { version = "0.12" } +percent-encoding = "2.1" +thiserror = "2.0.2" +tracing = { version = "0.1" } +url = "2.2" +walkdir = { version = "2", optional = true } + +# Cloud storage support +base64 = { version = "0.22", default-features = false, features = ["std"], optional = true } +form_urlencoded = { version = "1.2", optional = true } +http-body-util = { version = "0.1.2", optional = true } +httparse = { version = "1.8.0", default-features = false, features = ["std"], optional = true } +hyper = { version = "1.2", default-features = false, optional = true } +md-5 = { version = "0.10.6", default-features = false, optional = true } +quick-xml = { version = "0.37.0", features = ["serialize", "overlapped-lists"], optional = true } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2"], optional = true } +ring = { version = "0.17", default-features = false, features = ["std"], optional = true } +rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true } +serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } +serde_json = { version = "1.0", default-features = false, features = ["std"], optional = true } +serde_urlencoded = { version = "0.7", optional = true } +tokio = { version = "1.29.0", features = ["sync", "macros", "rt", "time", "io-util"] } + +[target.'cfg(target_family="unix")'.dev-dependencies] +nix = { version = "0.29.0", features = ["fs"] } + +[features] +default = ["fs"] +cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/stream", "chrono/serde", "base64", "rand", "ring", "http-body-util", "form_urlencoded", "serde_urlencoded"] +azure = ["cloud", "httparse"] +fs = ["walkdir"] +gcp = ["cloud", "rustls-pemfile"] +aws = ["cloud", "md-5"] +http = ["cloud"] +tls-webpki-roots = ["reqwest?/rustls-tls-webpki-roots"] +integration = [] + +[dev-dependencies] # In alphabetical order +hyper = { version = "1.2", features = ["server"] } +hyper-util = "0.1" +rand = "0.8" +tempfile = "3.1.0" +regex = "1.11.1" +# The "gzip" feature for reqwest is enabled for an integration test. +reqwest = { version = "0.12", features = ["gzip"] } + +[[test]] +name = "get_range_file" +path = "tests/get_range_file.rs" diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..de4b130 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,204 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 0000000..0a23eee --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,5 @@ +Apache Arrow Object Store +Copyright 2020-2024 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md new file mode 100644 index 0000000..1799bf8 --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ + + +# Rust Object Store + +A focused, easy to use, idiomatic, high performance, `async` object +store library for interacting with object stores. + +Using this crate, the same binary and code can easily run in multiple +clouds and local test environments, via a simple runtime configuration +change. Supported object stores include: + +* [AWS S3](https://aws.amazon.com/s3/) +* [Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/blobs/) +* [Google Cloud Storage](https://cloud.google.com/storage) +* Local files +* Memory +* [HTTP/WebDAV Storage](https://datatracker.ietf.org/doc/html/rfc2518) +* Custom implementations + +Originally developed by [InfluxData](https://www.influxdata.com/) and later donated to [Apache Arrow](https://arrow.apache.org/). + +See [docs.rs](https://docs.rs/object_store) for usage instructions + +## Support for `wasm32-unknown-unknown` target + +It's possible to build `object_store` for the `wasm32-unknown-unknown` target, however the cloud storage features `aws`, `azure`, `gcp`, and `http` are not supported. + +``` +cargo build -p object_store --target wasm32-unknown-unknown +``` \ No newline at end of file diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..bfd060a --- /dev/null +++ b/deny.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Configuration documentation: +#  https://embarkstudios.github.io/cargo-deny/index.html + +[advisories] +vulnerability = "deny" +yanked = "deny" +unmaintained = "warn" +notice = "warn" +ignore = [ +] +git-fetch-with-cli = true + +[licenses] +default = "allow" +unlicensed = "allow" +copyleft = "allow" + +[bans] +multiple-versions = "warn" +deny = [ + # We are using rustls as the TLS implementation, so we shouldn't be linking + # in OpenSSL too. + # + # If you're hitting this, you might want to take a look at what new + # dependencies you have introduced and check if there's a way to depend on + # rustls instead of OpenSSL (tip: check the crate's feature flags). + { name = "openssl-sys" } +] diff --git a/dev/release/README.md b/dev/release/README.md new file mode 100644 index 0000000..2dd1f62 --- /dev/null +++ b/dev/release/README.md @@ -0,0 +1,228 @@ + + + +# Release Process + +## Overview + +This file documents the release process for the `object_store` crate. 
+ +We release a new version of `object_store` according to the schedule listed in +the [main README.md] + +[main README.md]: https://github.com/apache/arrow-rs?tab=readme-ov-file#object_store-crate + +As we are still in an early phase, we use the 0.x version scheme. If any code has +been merged to main that has a breaking API change, as defined in [Rust RFC 1105] +the minor version number is incremented changed (e.g. `0.3.0` to `0.4.0`). +Otherwise the patch version is incremented (e.g. `0.3.0` to `0.3.1`). + +[Rust RFC 1105]: https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md +# Release Mechanics + +## Process Overview + +As part of the Apache governance model, official releases consist of +signed source tarballs approved by the PMC. + +We then use the code in the approved source tarball to release to +crates.io, the Rust ecosystem's package manager. + +We create a `CHANGELOG.md` so our users know what has been changed between releases. + +The CHANGELOG is created automatically using +[update_change_log.sh](https://github.com/apache/arrow-rs/blob/main/object_store/dev/release/update_change_log.sh) + +This script creates a changelog using github issues and the +labels associated with them. + +## Prepare CHANGELOG and version: + +Now prepare a PR to update `CHANGELOG.md` and versions on `main` to reflect the planned release. + +Note this process is done in the `object_store` directory. See [#6227] for an example + +[#6227]: https://github.com/apache/arrow-rs/pull/6227 + +```bash +# NOTE: Run commands in object_store sub directory (not main repo checkout) +# cd object_store + +git checkout main +git pull +git checkout -b + +# Update versions. Make sure to run it before the next step since we do not want CHANGELOG-old.md affected. +sed -i '' -e 's/0.11.0/0.11.1/g' `find . -name 'Cargo.toml' -or -name '*.md' | grep -v CHANGELOG` +git commit -a -m 'Update version' + +# ensure your github token is available +export CHANGELOG_GITHUB_TOKEN= + +# manually edit ./dev/release/update_change_log.sh to reflect the release version +# create the changelog +./dev/release/update_change_log.sh + +# review change log / and edit associated issues and labels if needed, rerun update_change_log.sh + +# Commit changes +git commit -a -m 'Create changelog' + +# push changes to fork and create a PR to main +git push +``` + +Note that when reviewing the change log, rather than editing the +`CHANGELOG.md`, it is preferred to update the issues and their labels +(e.g. add `invalid` label to exclude them from release notes) + +Merge this PR to `main` prior to the next step. + +## Prepare release candidate tarball + +After you have merged the updates to the `CHANGELOG` and version, +create a release candidate using the following steps. Note you need to +be a committer to run these scripts as they upload to the apache `svn` +distribution servers. + +### Create git tag for the release: + +While the official release artifact is a signed tarball, we also tag the commit it was created for convenience and code archaeology. + +For `object_store` releases, use a string such as `object_store_0.4.0` as the ``. + +Create and push the tag thusly: + +```shell +git fetch apache +git tag apache/main +# push tag to apache +git push apache +``` + +### Pick an Release Candidate (RC) number + +Pick numbers in sequential order, with `1` for `rc1`, `2` for `rc2`, etc. + +### Create, sign, and upload tarball + +Run `create-tarball.sh` with the `` tag and `` and you found in previous steps. 
+ +```shell +./object_store/dev/release/create-tarball.sh 0.11.1 1 +``` + +The `create-tarball.sh` script + +1. creates and uploads a release candidate tarball to the [arrow + dev](https://dist.apache.org/repos/dist/dev/arrow) location on the + apache distribution svn server + +2. provide you an email template to + send to dev@arrow.apache.org for release voting. + +### Vote on Release Candidate tarball + +Send an email, based on the output from the script to dev@arrow.apache.org. The email should look like + +``` +Draft email for dev@arrow.apache.org mailing list + +--------------------------------------------------------- +To: dev@arrow.apache.org +Subject: [VOTE][RUST] Release Apache Arrow Rust Object Store 0.11.1 RC1 + +Hi, + +I would like to propose a release of Apache Arrow Rust Object +Store Implementation, version 0.11.1. + +This release candidate is based on commit: b945b15de9085f5961a478d4f35b0c5c3427e248 [1] + +The proposed release tarball and signatures are hosted at [2]. + +The changelog is located at [3]. + +Please download, verify checksums and signatures, run the unit tests, +and vote on the release. There is a script [4] that automates some of +the verification. + +The vote will be open for at least 72 hours. + +[ ] +1 Release this as Apache Arrow Rust Object Store +[ ] +0 +[ ] -1 Do not release this as Apache Arrow Rust Object Store because... + +[1]: https://github.com/apache/arrow-rs/tree/b945b15de9085f5961a478d4f35b0c5c3427e248 +[2]: https://dist.apache.org/repos/dist/dev/arrow/apache-arrow-object-store-rs-0.11.1-rc1 +[3]: https://github.com/apache/arrow-rs/blob/b945b15de9085f5961a478d4f35b0c5c3427e248/object_store/CHANGELOG.md +[4]: https://github.com/apache/arrow-rs/blob/main/object_store/dev/release/verify-release-candidate.sh +``` + +For the release to become "official" it needs at least three Apache Arrow PMC members to vote +1 on it. + +## Verifying release candidates + +The `object_store/dev/release/verify-release-candidate.sh` script can assist in the verification process. Run it like: + +``` +./object_store/dev/release/verify-release-candidate.sh 0.11.0 1 +``` + +#### If the release is not approved + +If the release is not approved, fix whatever the problem is and try again with the next RC number + +### If the release is approved, + +Move tarball to the release location in SVN, e.g. https://dist.apache.org/repos/dist/release/arrow/arrow-4.1.0/, using the `release-tarball.sh` script: + + +```shell +./object_store/dev/release/release-tarball.sh 4.1.0 2 +``` + +Congratulations! The release is now official! + +### Publish on Crates.io + +Only approved releases of the tarball should be published to +crates.io, in order to conform to Apache Software Foundation +governance standards. + +An Arrow committer can publish this crate after an official project release has +been made to crates.io using the following instructions. + +Follow [these +instructions](https://doc.rust-lang.org/cargo/reference/publishing.html) to +create an account and login to crates.io before asking to be added as an owner +of the [arrow crate](https://crates.io/crates/arrow). + +Download and unpack the official release tarball + +Verify that the Cargo.toml in the tarball contains the correct version +(e.g. 
`version = "0.11.0"`) and then publish the crate with the +following commands + + +```shell +cargo publish +``` + diff --git a/dev/release/create-tarball.sh b/dev/release/create-tarball.sh new file mode 100755 index 0000000..efc26fd --- /dev/null +++ b/dev/release/create-tarball.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This script creates a signed tarball in +# dev/dist/apache-arrow-object-store-rs--.tar.gz and uploads it to +# the "dev" area of the dist.apache.arrow repository and prepares an +# email for sending to the dev@arrow.apache.org list for a formal +# vote. +# +# Note the tags are expected to be `object_sore_` +# +# See release/README.md for full release instructions +# +# Requirements: +# +# 1. gpg setup for signing and have uploaded your public +# signature to https://pgp.mit.edu/ +# +# 2. Logged into the apache svn server with the appropriate +# credentials +# +# +# Based in part on 02-source.sh from apache/arrow +# + +set -e + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "ex. $0 0.4.0 1" + exit +fi + +object_store_version=$1 +rc=$2 + +tag=object_store_${object_store_version} + +release=apache-arrow-object-store-rs-${object_store_version} +distdir=${SOURCE_TOP_DIR}/dev/dist/${release}-rc${rc} +tarname=${release}.tar.gz +tarball=${distdir}/${tarname} +url="https://dist.apache.org/repos/dist/dev/arrow/${release}-rc${rc}" + +echo "Attempting to create ${tarball} from tag ${tag}" + +release_hash=$(cd "${SOURCE_TOP_DIR}" && git rev-list --max-count=1 ${tag}) + +if [ -z "$release_hash" ]; then + echo "Cannot continue: unknown git tag: $tag" +fi + +echo "Draft email for dev@arrow.apache.org mailing list" +echo "" +echo "---------------------------------------------------------" +cat < containing the files in git at $release_hash +# the files in the tarball are prefixed with {object_store_version=} (e.g. 
0.4.0) +mkdir -p ${distdir} +(cd "${SOURCE_TOP_DIR}" && git archive ${release_hash} --prefix ${release}/ | gzip > ${tarball}) + +echo "Running rat license checker on ${tarball}" +${SOURCE_DIR}/../../../dev/release/run-rat.sh ${tarball} + +echo "Signing tarball and creating checksums" +gpg --armor --output ${tarball}.asc --detach-sig ${tarball} +# create signing with relative path of tarball +# so that they can be verified with a command such as +# shasum --check apache-arrow-rs-4.1.0-rc2.tar.gz.sha512 +(cd ${distdir} && shasum -a 256 ${tarname}) > ${tarball}.sha256 +(cd ${distdir} && shasum -a 512 ${tarname}) > ${tarball}.sha512 + +echo "Uploading to apache dist/dev to ${url}" +svn co --depth=empty https://dist.apache.org/repos/dist/dev/arrow ${SOURCE_TOP_DIR}/dev/dist +svn add ${distdir} +svn ci -m "Apache Arrow Rust ${object_store_version=} ${rc}" ${distdir} diff --git a/dev/release/release-tarball.sh b/dev/release/release-tarball.sh new file mode 100755 index 0000000..16b10e0 --- /dev/null +++ b/dev/release/release-tarball.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This script copies a tarball from the "dev" area of the +# dist.apache.arrow repository to the "release" area +# +# This script should only be run after the release has been approved +# by the arrow PMC committee. +# +# See release/README.md for full release instructions +# +# Based in part on post-01-upload.sh from apache/arrow + + +set -e +set -u + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "ex. $0 0.4.0 1" + exit +fi + +version=$1 +rc=$2 + +tmp_dir=tmp-apache-arrow-dist + +echo "Recreate temporary directory: ${tmp_dir}" +rm -rf ${tmp_dir} +mkdir -p ${tmp_dir} + +echo "Clone dev dist repository" +svn \ + co \ + https://dist.apache.org/repos/dist/dev/arrow/apache-arrow-object-store-rs-${version}-rc${rc} \ + ${tmp_dir}/dev + +echo "Clone release dist repository" +svn co https://dist.apache.org/repos/dist/release/arrow ${tmp_dir}/release + +echo "Copy ${version}-rc${rc} to release working copy" +release_version=arrow-object-store-rs-${version} +mkdir -p ${tmp_dir}/release/${release_version} +cp -r ${tmp_dir}/dev/* ${tmp_dir}/release/${release_version}/ +svn add ${tmp_dir}/release/${release_version} + +echo "Commit release" +svn ci -m "Apache Arrow Rust Object Store ${version}" ${tmp_dir}/release + +echo "Clean up" +rm -rf ${tmp_dir} + +echo "Success!" 
+echo "The release is available here:" +echo " https://dist.apache.org/repos/dist/release/arrow/${release_version}" + +echo "Clean up old artifacts from svn" +"${SOURCE_TOP_DIR}"/dev/release/remove-old-artifacts.sh diff --git a/dev/release/remove-old-artifacts.sh b/dev/release/remove-old-artifacts.sh new file mode 100755 index 0000000..bbbbe0c --- /dev/null +++ b/dev/release/remove-old-artifacts.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This script removes all RCs and all but the most recent versions of +# object_store from svn. +# +# The older versions are in SVN history as well as available on the +# archive page https://archive.apache.org/dist/ +# +# See +# https://infra.apache.org/release-download-pages.html + +set -e +set -u +set -o pipefail + +echo "Remove all RCs" +dev_base_url=https://dist.apache.org/repos/dist/dev/arrow +old_rcs=$( + svn ls ${dev_base_url}/ | \ + grep -E '^apache-arrow-object-store-rs-[0-9]' | \ + sort --version-sort +) +for old_rc in $old_rcs; do + echo "Remove RC: ${old_rc}" + svn \ + delete \ + -m "Remove old Apache Arrow Rust Object Store RC: ${old_rc}" \ + ${dev_base_url}/${old_rc} +done + +echo "Remove all but the most recent version" +release_base_url="https://dist.apache.org/repos/dist/release/arrow" +old_releases=$( + svn ls ${release_base_url} | \ + grep -E '^arrow-object-store-rs-[0-9\.]+' | \ + sort --version-sort --reverse | \ + tail -n +2 +) +for old_release_version in $old_releases; do + echo "Remove old release: ${old_release_version}" + svn \ + delete \ + -m "Remove Apache Arrow Rust Object Store release: ${old_release_version}" \ + ${release_base_url}/${old_release_version} +done diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh new file mode 100755 index 0000000..f52c9f4 --- /dev/null +++ b/dev/release/update_change_log.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +# invokes the changelog generator from +# https://github.com/github-changelog-generator/github-changelog-generator +# +# With the config located in +# arrow-rs/object_store/.github_changelog_generator +# +# Usage: +# CHANGELOG_GITHUB_TOKEN= ./update_change_log.sh + +set -e + +SINCE_TAG="object_store_0.11.2" +FUTURE_RELEASE="object_store_0.12.0" + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +OUTPUT_PATH="${SOURCE_TOP_DIR}/CHANGELOG.md" + +# remove license header so github-changelog-generator has a clean base to append +sed -i.bak '1,18d' "${OUTPUT_PATH}" + +# use exclude-tags-regex to filter out tags used for arrow +# crates and only look at tags that begin with `object_store_` +pushd "${SOURCE_TOP_DIR}" +docker run -it --rm -e CHANGELOG_GITHUB_TOKEN="$CHANGELOG_GITHUB_TOKEN" -v "$(pwd)":/usr/local/src/your-app githubchangeloggenerator/github-changelog-generator \ + --user apache \ + --project arrow-rs \ + --cache-file=.githubchangeloggenerator.cache \ + --cache-log=.githubchangeloggenerator.cache.log \ + --http-cache \ + --max-issues=600 \ + --include-labels="object-store" \ + --exclude-tags-regex "(^\d+\.\d+\.\d+$)|(rc)" \ + --since-tag ${SINCE_TAG} \ + --future-release ${FUTURE_RELEASE} + +sed -i.bak "s/\\\n/\n\n/" "${OUTPUT_PATH}" + +# Put license header back on +echo ' +' | cat - "${OUTPUT_PATH}" > "${OUTPUT_PATH}".tmp +mv "${OUTPUT_PATH}".tmp "${OUTPUT_PATH}" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh new file mode 100755 index 0000000..b24bd8f --- /dev/null +++ b/dev/release/verify-release-candidate.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +case $# in + 2) VERSION="$1" + RC_NUMBER="$2" + ;; + *) echo "Usage: $0 X.Y.Z RC_NUMBER" + exit 1 + ;; +esac + +set -e +set -x +set -o pipefail + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" +ARROW_DIR="$(dirname $(dirname ${SOURCE_DIR}))" +ARROW_DIST_URL='https://dist.apache.org/repos/dist/dev/arrow' + +download_dist_file() { + curl \ + --silent \ + --show-error \ + --fail \ + --location \ + --remote-name $ARROW_DIST_URL/$1 +} + +download_rc_file() { + download_dist_file apache-arrow-object-store-rs-${VERSION}-rc${RC_NUMBER}/$1 +} + +import_gpg_keys() { + download_dist_file KEYS + gpg --import KEYS +} + +if type shasum >/dev/null 2>&1; then + sha256_verify="shasum -a 256 -c" + sha512_verify="shasum -a 512 -c" +else + sha256_verify="sha256sum -c" + sha512_verify="sha512sum -c" +fi + +fetch_archive() { + local dist_name=$1 + download_rc_file ${dist_name}.tar.gz + download_rc_file ${dist_name}.tar.gz.asc + download_rc_file ${dist_name}.tar.gz.sha256 + download_rc_file ${dist_name}.tar.gz.sha512 + gpg --verify ${dist_name}.tar.gz.asc ${dist_name}.tar.gz + ${sha256_verify} ${dist_name}.tar.gz.sha256 + ${sha512_verify} ${dist_name}.tar.gz.sha512 +} + +setup_tempdir() { + cleanup() { + if [ "${TEST_SUCCESS}" = "yes" ]; then + rm -fr "${ARROW_TMPDIR}" + else + echo "Failed to verify release candidate. See ${ARROW_TMPDIR} for details." + fi + } + + if [ -z "${ARROW_TMPDIR}" ]; then + # clean up automatically if ARROW_TMPDIR is not defined + ARROW_TMPDIR=$(mktemp -d -t "$1.XXXXX") + trap cleanup EXIT + else + # don't clean up automatically + mkdir -p "${ARROW_TMPDIR}" + fi +} + +test_source_distribution() { + # install rust toolchain in a similar fashion like test-miniconda + export RUSTUP_HOME=$PWD/test-rustup + export CARGO_HOME=$PWD/test-rustup + + curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path + + export PATH=$RUSTUP_HOME/bin:$PATH + source $RUSTUP_HOME/env + + # build and test rust + cargo build + cargo test --all --all-features + + # verify that the crate can be published to crates.io + cargo publish --dry-run +} + +TEST_SUCCESS=no + +setup_tempdir "arrow-${VERSION}" +echo "Working in sandbox ${ARROW_TMPDIR}" +cd ${ARROW_TMPDIR} + +dist_name="apache-arrow-object-store-rs-${VERSION}" +import_gpg_keys +fetch_archive ${dist_name} +tar xf ${dist_name}.tar.gz +pushd ${dist_name} +test_source_distribution +popd + +TEST_SUCCESS=yes +echo 'Release candidate looks good!' +exit 0 diff --git a/src/attributes.rs b/src/attributes.rs new file mode 100644 index 0000000..11cf27c --- /dev/null +++ b/src/attributes.rs @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::borrow::Cow; +use std::collections::HashMap; +use std::ops::Deref; + +/// Additional object attribute types +#[non_exhaustive] +#[derive(Debug, Hash, Eq, PartialEq, Clone)] +pub enum Attribute { + /// Specifies how the object should be handled by a browser + /// + /// See [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) + ContentDisposition, + /// Specifies the encodings applied to the object + /// + /// See [Content-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding) + ContentEncoding, + /// Specifies the language of the object + /// + /// See [Content-Language](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Language) + ContentLanguage, + /// Specifies the MIME type of the object + /// + /// This takes precedence over any [ClientOptions](crate::ClientOptions) configuration + /// + /// See [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) + ContentType, + /// Overrides cache control policy of the object + /// + /// See [Cache-Control](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control) + CacheControl, + /// Specifies a user-defined metadata field for the object + /// + /// The String is a user-defined key + Metadata(Cow<'static, str>), +} + +/// The value of an [`Attribute`] +/// +/// Provides efficient conversion from both static and owned strings +/// +/// ``` +/// # use object_store::AttributeValue; +/// // Can use static strings without needing an allocation +/// let value = AttributeValue::from("bar"); +/// // Can also store owned strings +/// let value = AttributeValue::from("foo".to_string()); +/// ``` +#[derive(Debug, Hash, Eq, PartialEq, Clone)] +pub struct AttributeValue(Cow<'static, str>); + +impl AsRef for AttributeValue { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl From<&'static str> for AttributeValue { + fn from(value: &'static str) -> Self { + Self(Cow::Borrowed(value)) + } +} + +impl From for AttributeValue { + fn from(value: String) -> Self { + Self(Cow::Owned(value)) + } +} + +impl Deref for AttributeValue { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +/// Additional attributes of an object +/// +/// Attributes can be specified in [PutOptions](crate::PutOptions) and retrieved +/// from APIs returning [GetResult](crate::GetResult). 
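+///
+/// A minimal usage sketch, assuming `Attribute` and `Attributes` are exported at
+/// the crate root in the same way as [`AttributeValue`]:
+///
+/// ```
+/// # use object_store::{Attribute, Attributes};
+/// // insert and read back a Content-Type attribute
+/// let mut attributes = Attributes::new();
+/// attributes.insert(Attribute::ContentType, "application/json".into());
+/// assert_eq!(
+///     attributes.get(&Attribute::ContentType),
+///     Some(&"application/json".into())
+/// );
+/// ```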
+/// +/// Unlike [`ObjectMeta`](crate::ObjectMeta), [`Attributes`] are not returned by +/// listing APIs +#[derive(Debug, Default, Eq, PartialEq, Clone)] +pub struct Attributes(HashMap); + +impl Attributes { + /// Create a new empty [`Attributes`] + pub fn new() -> Self { + Self::default() + } + + /// Create a new [`Attributes`] with space for `capacity` [`Attribute`] + pub fn with_capacity(capacity: usize) -> Self { + Self(HashMap::with_capacity(capacity)) + } + + /// Insert a new [`Attribute`], [`AttributeValue`] pair + /// + /// Returns the previous value for `key` if any + pub fn insert(&mut self, key: Attribute, value: AttributeValue) -> Option { + self.0.insert(key, value) + } + + /// Returns the [`AttributeValue`] for `key` if any + pub fn get(&self, key: &Attribute) -> Option<&AttributeValue> { + self.0.get(key) + } + + /// Removes the [`AttributeValue`] for `key` if any + pub fn remove(&mut self, key: &Attribute) -> Option { + self.0.remove(key) + } + + /// Returns an [`AttributesIter`] over this + pub fn iter(&self) -> AttributesIter<'_> { + self.into_iter() + } + + /// Returns the number of [`Attribute`] in this collection + #[inline] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if this contains no [`Attribute`] + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl FromIterator<(K, V)> for Attributes +where + K: Into, + V: Into, +{ + fn from_iter>(iter: T) -> Self { + Self( + iter.into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), + ) + } +} + +impl<'a> IntoIterator for &'a Attributes { + type Item = (&'a Attribute, &'a AttributeValue); + type IntoIter = AttributesIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + AttributesIter(self.0.iter()) + } +} + +/// Iterator over [`Attributes`] +#[derive(Debug)] +pub struct AttributesIter<'a>(std::collections::hash_map::Iter<'a, Attribute, AttributeValue>); + +impl<'a> Iterator for AttributesIter<'a> { + type Item = (&'a Attribute, &'a AttributeValue); + + fn next(&mut self) -> Option { + self.0.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_attributes_basic() { + let mut attributes = Attributes::from_iter([ + (Attribute::ContentDisposition, "inline"), + (Attribute::ContentEncoding, "gzip"), + (Attribute::ContentLanguage, "en-US"), + (Attribute::ContentType, "test"), + (Attribute::CacheControl, "control"), + (Attribute::Metadata("key1".into()), "value1"), + ]); + + assert!(!attributes.is_empty()); + assert_eq!(attributes.len(), 6); + + assert_eq!( + attributes.get(&Attribute::ContentType), + Some(&"test".into()) + ); + + let metav = "control".into(); + assert_eq!(attributes.get(&Attribute::CacheControl), Some(&metav)); + assert_eq!( + attributes.insert(Attribute::CacheControl, "v1".into()), + Some(metav) + ); + assert_eq!(attributes.len(), 6); + + assert_eq!( + attributes.remove(&Attribute::CacheControl).unwrap(), + "v1".into() + ); + assert_eq!(attributes.len(), 5); + + let metav: AttributeValue = "v2".into(); + attributes.insert(Attribute::CacheControl, metav.clone()); + assert_eq!(attributes.get(&Attribute::CacheControl), Some(&metav)); + assert_eq!(attributes.len(), 6); + + assert_eq!( + attributes.get(&Attribute::ContentDisposition), + Some(&"inline".into()) + ); + assert_eq!( + attributes.get(&Attribute::ContentEncoding), + Some(&"gzip".into()) + ); + assert_eq!( + attributes.get(&Attribute::ContentLanguage), + Some(&"en-US".into()) + ); + assert_eq!( + 
attributes.get(&Attribute::Metadata("key1".into())), + Some(&"value1".into()) + ); + } +} diff --git a/src/aws/builder.rs b/src/aws/builder.rs new file mode 100644 index 0000000..5dff94d --- /dev/null +++ b/src/aws/builder.rs @@ -0,0 +1,1544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aws::client::{S3Client, S3Config}; +use crate::aws::credential::{ + InstanceCredentialProvider, SessionProvider, TaskCredentialProvider, WebIdentityProvider, +}; +use crate::aws::{ + AmazonS3, AwsCredential, AwsCredentialProvider, Checksum, S3ConditionalPut, S3CopyIfNotExists, + STORE, +}; +use crate::client::{http_connector, HttpConnector, TokenCredentialProvider}; +use crate::config::ConfigValue; +use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider}; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use itertools::Itertools; +use md5::{Digest, Md5}; +use reqwest::header::{HeaderMap, HeaderValue}; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use tracing::info; +use url::Url; + +/// Default metadata endpoint +static DEFAULT_METADATA_ENDPOINT: &str = "http://169.254.169.254"; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Missing bucket name")] + MissingBucketName, + + #[error("Missing AccessKeyId")] + MissingAccessKeyId, + + #[error("Missing SecretAccessKey")] + MissingSecretAccessKey, + + #[error("Unable parse source url. Url: {}, Error: {}", url, source)] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[error( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + )] + UnknownUrlScheme { scheme: String }, + + #[error("URL did not match any known pattern for scheme: {}", url)] + UrlNotRecognised { url: String }, + + #[error("Configuration key: '{}' is not known.", key)] + UnknownConfigurationKey { key: String }, + + #[error("Invalid Zone suffix for bucket '{bucket}'")] + ZoneSuffix { bucket: String }, + + #[error("Invalid encryption type: {}. Valid values are \"AES256\", \"sse:kms\", \"sse:kms:dsse\" and \"sse-c\".", passed)] + InvalidEncryptionType { passed: String }, + + #[error( + "Invalid encryption header values. 
Header: {}, source: {}", + header, + source + )] + InvalidEncryptionHeader { + header: &'static str, + source: Box, + }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + match source { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(source), + }, + } + } +} + +/// Configure a connection to Amazon S3 using the specified credentials in +/// the specified Amazon region and bucket. +/// +/// # Example +/// ``` +/// # let REGION = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY_ID = "foo"; +/// # let SECRET_KEY = "foo"; +/// # use object_store::aws::AmazonS3Builder; +/// let s3 = AmazonS3Builder::new() +/// .with_region(REGION) +/// .with_bucket_name(BUCKET_NAME) +/// .with_access_key_id(ACCESS_KEY_ID) +/// .with_secret_access_key(SECRET_KEY) +/// .build(); +/// ``` +#[derive(Debug, Default, Clone)] +pub struct AmazonS3Builder { + /// Access key id + access_key_id: Option, + /// Secret access_key + secret_access_key: Option, + /// Region + region: Option, + /// Bucket name + bucket_name: Option, + /// Endpoint for communicating with AWS S3 + endpoint: Option, + /// Token to use for requests + token: Option, + /// Url + url: Option, + /// Retry config + retry_config: RetryConfig, + /// When set to true, fallback to IMDSv1 + imdsv1_fallback: ConfigValue, + /// When set to true, virtual hosted style request has to be used + virtual_hosted_style_request: ConfigValue, + /// When set to true, S3 express is used + s3_express: ConfigValue, + /// When set to true, unsigned payload option has to be used + unsigned_payload: ConfigValue, + /// Checksum algorithm which has to be used for object integrity check during upload + checksum_algorithm: Option>, + /// Metadata endpoint, see + metadata_endpoint: Option, + /// Container credentials URL, see + container_credentials_relative_uri: Option, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, + /// Copy if not exists + copy_if_not_exists: Option>, + /// Put precondition + conditional_put: ConfigValue, + /// Ignore tags + disable_tagging: ConfigValue, + /// Encryption (See [`S3EncryptionConfigKey`]) + encryption_type: Option>, + encryption_kms_key_id: Option, + encryption_bucket_key_enabled: Option>, + /// base64-encoded 256-bit customer encryption key for SSE-C. + encryption_customer_key_base64: Option, + /// When set to true, charge requester for bucket operations + request_payer: ConfigValue, + /// The [`HttpConnector`] to use + http_connector: Option>, +} + +/// Configuration keys for [`AmazonS3Builder`] +/// +/// Configuration via keys can be done via [`AmazonS3Builder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; +/// let builder = AmazonS3Builder::new() +/// .with_config("aws_access_key_id".parse().unwrap(), "my-access-key-id") +/// .with_config(AmazonS3ConfigKey::DefaultRegion, "my-default-region"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] +#[non_exhaustive] +pub enum AmazonS3ConfigKey { + /// AWS Access Key + /// + /// See [`AmazonS3Builder::with_access_key_id`] for details. + /// + /// Supported keys: + /// - `aws_access_key_id` + /// - `access_key_id` + AccessKeyId, + + /// Secret Access Key + /// + /// See [`AmazonS3Builder::with_secret_access_key`] for details. 
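+    ///
+    /// A brief sketch of setting this key via [`AmazonS3Builder::with_config`],
+    /// mirroring the `with_config` example above (the value shown is a placeholder):
+    ///
+    /// ```
+    /// # use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey};
+    /// let builder = AmazonS3Builder::new()
+    ///     .with_config(AmazonS3ConfigKey::SecretAccessKey, "my-secret-access-key");
+    /// ```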
+ /// + /// Supported keys: + /// - `aws_secret_access_key` + /// - `secret_access_key` + SecretAccessKey, + + /// Region + /// + /// See [`AmazonS3Builder::with_region`] for details. + /// + /// Supported keys: + /// - `aws_region` + /// - `region` + Region, + + /// Default region + /// + /// See [`AmazonS3Builder::with_region`] for details. + /// + /// Supported keys: + /// - `aws_default_region` + /// - `default_region` + DefaultRegion, + + /// Bucket name + /// + /// See [`AmazonS3Builder::with_bucket_name`] for details. + /// + /// Supported keys: + /// - `aws_bucket` + /// - `aws_bucket_name` + /// - `bucket` + /// - `bucket_name` + Bucket, + + /// Sets custom endpoint for communicating with AWS S3. + /// + /// See [`AmazonS3Builder::with_endpoint`] for details. + /// + /// Supported keys: + /// - `aws_endpoint` + /// - `aws_endpoint_url` + /// - `endpoint` + /// - `endpoint_url` + Endpoint, + + /// Token to use for requests (passed to underlying provider) + /// + /// See [`AmazonS3Builder::with_token`] for details. + /// + /// Supported keys: + /// - `aws_session_token` + /// - `aws_token` + /// - `session_token` + /// - `token` + Token, + + /// Fall back to ImdsV1 + /// + /// See [`AmazonS3Builder::with_imdsv1_fallback`] for details. + /// + /// Supported keys: + /// - `aws_imdsv1_fallback` + /// - `imdsv1_fallback` + ImdsV1Fallback, + + /// If virtual hosted style request has to be used + /// + /// See [`AmazonS3Builder::with_virtual_hosted_style_request`] for details. + /// + /// Supported keys: + /// - `aws_virtual_hosted_style_request` + /// - `virtual_hosted_style_request` + VirtualHostedStyleRequest, + + /// Avoid computing payload checksum when calculating signature. + /// + /// See [`AmazonS3Builder::with_unsigned_payload`] for details. + /// + /// Supported keys: + /// - `aws_unsigned_payload` + /// - `unsigned_payload` + UnsignedPayload, + + /// Set the checksum algorithm for this client + /// + /// See [`AmazonS3Builder::with_checksum_algorithm`] + Checksum, + + /// Set the instance metadata endpoint + /// + /// See [`AmazonS3Builder::with_metadata_endpoint`] for details. 
+ /// + /// Supported keys: + /// - `aws_metadata_endpoint` + /// - `metadata_endpoint` + MetadataEndpoint, + + /// Set the container credentials relative URI + /// + /// + ContainerCredentialsRelativeUri, + + /// Configure how to provide `copy_if_not_exists` + /// + /// See [`S3CopyIfNotExists`] + CopyIfNotExists, + + /// Configure how to provide conditional put operations + /// + /// See [`S3ConditionalPut`] + ConditionalPut, + + /// Skip signing request + SkipSignature, + + /// Disable tagging objects + /// + /// This can be desirable if not supported by the backing store + /// + /// Supported keys: + /// - `aws_disable_tagging` + /// - `disable_tagging` + DisableTagging, + + /// Enable Support for S3 Express One Zone + /// + /// Supported keys: + /// - `aws_s3_express` + /// - `s3_express` + S3Express, + + /// Enable Support for S3 Requester Pays + /// + /// Supported keys: + /// - `aws_request_payer` + /// - `request_payer` + RequestPayer, + + /// Client options + Client(ClientConfigKey), + + /// Encryption options + Encryption(S3EncryptionConfigKey), +} + +impl AsRef for AmazonS3ConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::AccessKeyId => "aws_access_key_id", + Self::SecretAccessKey => "aws_secret_access_key", + Self::Region => "aws_region", + Self::Bucket => "aws_bucket", + Self::Endpoint => "aws_endpoint", + Self::Token => "aws_session_token", + Self::ImdsV1Fallback => "aws_imdsv1_fallback", + Self::VirtualHostedStyleRequest => "aws_virtual_hosted_style_request", + Self::S3Express => "aws_s3_express", + Self::DefaultRegion => "aws_default_region", + Self::MetadataEndpoint => "aws_metadata_endpoint", + Self::UnsignedPayload => "aws_unsigned_payload", + Self::Checksum => "aws_checksum_algorithm", + Self::ContainerCredentialsRelativeUri => "aws_container_credentials_relative_uri", + Self::SkipSignature => "aws_skip_signature", + Self::CopyIfNotExists => "aws_copy_if_not_exists", + Self::ConditionalPut => "aws_conditional_put", + Self::DisableTagging => "aws_disable_tagging", + Self::RequestPayer => "aws_request_payer", + Self::Client(opt) => opt.as_ref(), + Self::Encryption(opt) => opt.as_ref(), + } + } +} + +impl FromStr for AmazonS3ConfigKey { + type Err = crate::Error; + + fn from_str(s: &str) -> Result { + match s { + "aws_access_key_id" | "access_key_id" => Ok(Self::AccessKeyId), + "aws_secret_access_key" | "secret_access_key" => Ok(Self::SecretAccessKey), + "aws_default_region" | "default_region" => Ok(Self::DefaultRegion), + "aws_region" | "region" => Ok(Self::Region), + "aws_bucket" | "aws_bucket_name" | "bucket_name" | "bucket" => Ok(Self::Bucket), + "aws_endpoint_url" | "aws_endpoint" | "endpoint_url" | "endpoint" => Ok(Self::Endpoint), + "aws_session_token" | "aws_token" | "session_token" | "token" => Ok(Self::Token), + "aws_virtual_hosted_style_request" | "virtual_hosted_style_request" => { + Ok(Self::VirtualHostedStyleRequest) + } + "aws_s3_express" | "s3_express" => Ok(Self::S3Express), + "aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback), + "aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint), + "aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload), + "aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum), + "aws_container_credentials_relative_uri" => Ok(Self::ContainerCredentialsRelativeUri), + "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), + "aws_copy_if_not_exists" | "copy_if_not_exists" => Ok(Self::CopyIfNotExists), + "aws_conditional_put" | 
"conditional_put" => Ok(Self::ConditionalPut), + "aws_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging), + "aws_request_payer" | "request_payer" => Ok(Self::RequestPayer), + // Backwards compatibility + "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), + "aws_server_side_encryption" => Ok(Self::Encryption( + S3EncryptionConfigKey::ServerSideEncryption, + )), + "aws_sse_kms_key_id" => Ok(Self::Encryption(S3EncryptionConfigKey::KmsKeyId)), + "aws_sse_bucket_key_enabled" => { + Ok(Self::Encryption(S3EncryptionConfigKey::BucketKeyEnabled)) + } + "aws_sse_customer_key_base64" => Ok(Self::Encryption( + S3EncryptionConfigKey::CustomerEncryptionKey, + )), + _ => match s.strip_prefix("aws_").unwrap_or(s).parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl AmazonS3Builder { + /// Create a new [`AmazonS3Builder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Fill the [`AmazonS3Builder`] with regular AWS environment variables + /// + /// All environment variables starting with `AWS_` will be evaluated. Names must + /// match acceptable input to [`AmazonS3ConfigKey::from_str`]. Only upper-case environment + /// variables are accepted. + /// + /// Some examples of variables extracted from environment: + /// * `AWS_ACCESS_KEY_ID` -> access_key_id + /// * `AWS_SECRET_ACCESS_KEY` -> secret_access_key + /// * `AWS_DEFAULT_REGION` -> region + /// * `AWS_ENDPOINT` -> endpoint + /// * `AWS_SESSION_TOKEN` -> token + /// * `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` -> + /// * `AWS_ALLOW_HTTP` -> set to "true" to permit HTTP connections without TLS + /// * `AWS_REQUEST_PAYER` -> set to "true" to permit operations on requester-pays buckets. + /// # Example + /// ``` + /// use object_store::aws::AmazonS3Builder; + /// + /// let s3 = AmazonS3Builder::from_env() + /// .with_bucket_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder: Self = Default::default(); + + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("AWS_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + builder + } + + /// Parse available connection info form a well-known storage URL. + /// + /// The supported url schemes are: + /// + /// - `s3:///` + /// - `s3a:///` + /// - `https://s3..amazonaws.com/` + /// - `https://.s3..amazonaws.com` + /// - `https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::aws::AmazonS3Builder; + /// + /// let s3 = AmazonS3Builder::from_env() + /// .with_url("s3://bucket/path") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair. 
+ pub fn with_config(mut self, key: AmazonS3ConfigKey, value: impl Into) -> Self { + match key { + AmazonS3ConfigKey::AccessKeyId => self.access_key_id = Some(value.into()), + AmazonS3ConfigKey::SecretAccessKey => self.secret_access_key = Some(value.into()), + AmazonS3ConfigKey::Region => self.region = Some(value.into()), + AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()), + AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()), + AmazonS3ConfigKey::Token => self.token = Some(value.into()), + AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value), + AmazonS3ConfigKey::VirtualHostedStyleRequest => { + self.virtual_hosted_style_request.parse(value) + } + AmazonS3ConfigKey::S3Express => self.s3_express.parse(value), + AmazonS3ConfigKey::DefaultRegion => { + self.region = self.region.or_else(|| Some(value.into())) + } + AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint = Some(value.into()), + AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm = Some(ConfigValue::Deferred(value.into())) + } + AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { + self.container_credentials_relative_uri = Some(value.into()) + } + AmazonS3ConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), + AmazonS3ConfigKey::DisableTagging => self.disable_tagging.parse(value), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) + } + AmazonS3ConfigKey::ConditionalPut => { + self.conditional_put = ConfigValue::Deferred(value.into()) + } + AmazonS3ConfigKey::RequestPayer => { + self.request_payer = ConfigValue::Deferred(value.into()) + } + AmazonS3ConfigKey::Encryption(key) => match key { + S3EncryptionConfigKey::ServerSideEncryption => { + self.encryption_type = Some(ConfigValue::Deferred(value.into())) + } + S3EncryptionConfigKey::KmsKeyId => self.encryption_kms_key_id = Some(value.into()), + S3EncryptionConfigKey::BucketKeyEnabled => { + self.encryption_bucket_key_enabled = Some(ConfigValue::Deferred(value.into())) + } + S3EncryptionConfigKey::CustomerEncryptionKey => { + self.encryption_customer_key_base64 = Some(value.into()) + } + }, + }; + self + } + + /// Get config value via a [`AmazonS3ConfigKey`]. 
+ /// + /// # Example + /// ``` + /// use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; + /// + /// let builder = AmazonS3Builder::from_env() + /// .with_bucket_name("foo"); + /// let bucket_name = builder.get_config_value(&AmazonS3ConfigKey::Bucket).unwrap_or_default(); + /// assert_eq!("foo", &bucket_name); + /// ``` + pub fn get_config_value(&self, key: &AmazonS3ConfigKey) -> Option { + match key { + AmazonS3ConfigKey::AccessKeyId => self.access_key_id.clone(), + AmazonS3ConfigKey::SecretAccessKey => self.secret_access_key.clone(), + AmazonS3ConfigKey::Region | AmazonS3ConfigKey::DefaultRegion => self.region.clone(), + AmazonS3ConfigKey::Bucket => self.bucket_name.clone(), + AmazonS3ConfigKey::Endpoint => self.endpoint.clone(), + AmazonS3ConfigKey::Token => self.token.clone(), + AmazonS3ConfigKey::ImdsV1Fallback => Some(self.imdsv1_fallback.to_string()), + AmazonS3ConfigKey::VirtualHostedStyleRequest => { + Some(self.virtual_hosted_style_request.to_string()) + } + AmazonS3ConfigKey::S3Express => Some(self.s3_express.to_string()), + AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(), + AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm.as_ref().map(ToString::to_string) + } + AmazonS3ConfigKey::Client(key) => self.client_options.get_config_value(key), + AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { + self.container_credentials_relative_uri.clone() + } + AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists.as_ref().map(ToString::to_string) + } + AmazonS3ConfigKey::ConditionalPut => Some(self.conditional_put.to_string()), + AmazonS3ConfigKey::DisableTagging => Some(self.disable_tagging.to_string()), + AmazonS3ConfigKey::RequestPayer => Some(self.request_payer.to_string()), + AmazonS3ConfigKey::Encryption(key) => match key { + S3EncryptionConfigKey::ServerSideEncryption => { + self.encryption_type.as_ref().map(ToString::to_string) + } + S3EncryptionConfigKey::KmsKeyId => self.encryption_kms_key_id.clone(), + S3EncryptionConfigKey::BucketKeyEnabled => self + .encryption_bucket_key_enabled + .as_ref() + .map(ToString::to_string), + S3EncryptionConfigKey::CustomerEncryptionKey => { + self.encryption_customer_key_base64.clone() + } + }, + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).map_err(|source| { + let url = url.into(); + Error::UnableToParseUrl { url, source } + })?; + + let host = parsed + .host_str() + .ok_or_else(|| Error::UrlNotRecognised { url: url.into() })?; + + match parsed.scheme() { + "s3" | "s3a" => self.bucket_name = Some(host.to_string()), + "https" => match host.splitn(4, '.').collect_tuple() { + Some(("s3", region, "amazonaws", "com")) => { + self.region = Some(region.to_string()); + let bucket = parsed.path_segments().into_iter().flatten().next(); + if let Some(bucket) = bucket { + self.bucket_name = Some(bucket.into()); + } + } + Some((bucket, "s3", region, "amazonaws.com")) => { + self.bucket_name = Some(bucket.to_string()); + self.region = Some(region.to_string()); + self.virtual_hosted_style_request = true.into(); + } + Some((account, "r2", "cloudflarestorage", "com")) => { + self.region = 
Some("auto".to_string()); + let endpoint = format!("https://{account}.r2.cloudflarestorage.com"); + self.endpoint = Some(endpoint); + + let bucket = parsed.path_segments().into_iter().flatten().next(); + if let Some(bucket) = bucket { + self.bucket_name = Some(bucket.into()); + } + } + _ => return Err(Error::UrlNotRecognised { url: url.into() }.into()), + }, + scheme => { + let scheme = scheme.into(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } + }; + Ok(()) + } + + /// Set the AWS Access Key + pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = Some(access_key_id.into()); + self + } + + /// Set the AWS Secret Access Key + pub fn with_secret_access_key(mut self, secret_access_key: impl Into) -> Self { + self.secret_access_key = Some(secret_access_key.into()); + self + } + + /// Set the AWS Session Token to use for requests + pub fn with_token(mut self, token: impl Into) -> Self { + self.token = Some(token.into()); + self + } + + /// Set the region, defaults to `us-east-1` + pub fn with_region(mut self, region: impl Into) -> Self { + self.region = Some(region.into()); + self + } + + /// Set the bucket_name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Sets the endpoint for communicating with AWS S3, defaults to the [region endpoint] + /// + /// For example, this might be set to `"http://localhost:4566:` + /// for testing against a localstack instance. + /// + /// The `endpoint` field should be consistent with [`Self::with_virtual_hosted_style_request`], + /// i.e. if `virtual_hosted_style_request` is set to true then `endpoint` + /// should have the bucket name included. + /// + /// By default, only HTTPS schemes are enabled. To connect to an HTTP endpoint, enable + /// [`Self::with_allow_http`]. + /// + /// [region endpoint]: https://docs.aws.amazon.com/general/latest/gr/s3.html + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = Some(endpoint.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.client_options = self.client_options.with_allow_http(allow_http); + self + } + + /// Sets if virtual hosted style request has to be used. + /// + /// If `virtual_hosted_style_request` is: + /// * false (default): Path style request is used + /// * true: Virtual hosted style request is used + /// + /// If the `endpoint` is provided then it should be + /// consistent with `virtual_hosted_style_request`. + /// i.e. if `virtual_hosted_style_request` is set to true + /// then `endpoint` should have bucket name included. 
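As a concrete sketch of the endpoint / addressing interplay described above, assuming a LocalStack-style emulator on `http://localhost:4566` and placeholder credentials:

```rust
use object_store::aws::{AmazonS3, AmazonS3Builder};

fn localstack_store() -> object_store::Result<AmazonS3> {
    // Path-style addressing (the default): requests go to
    // http://localhost:4566/my-bucket/<path>. With
    // with_virtual_hosted_style_request(true) the endpoint itself would have
    // to address the bucket instead.
    AmazonS3Builder::new()
        .with_region("us-east-1")
        .with_bucket_name("my-bucket")
        .with_endpoint("http://localhost:4566")
        .with_allow_http(true)
        .with_access_key_id("test")
        .with_secret_access_key("test")
        .build()
}
```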
+ pub fn with_virtual_hosted_style_request(mut self, virtual_hosted_style_request: bool) -> Self { + self.virtual_hosted_style_request = virtual_hosted_style_request.into(); + self + } + + /// Configure this as an S3 Express One Zone Bucket + pub fn with_s3_express(mut self, s3_express: bool) -> Self { + self.s3_express = s3_express.into(); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends + /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack] + /// + /// However, certain deployment environments, such as those running old versions of kube2iam, + /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1 + /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported. + /// + /// This option has no effect if not using instance credentials + /// + /// [IMDSv2]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + /// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/ + /// + pub fn with_imdsv1_fallback(mut self) -> Self { + self.imdsv1_fallback = true.into(); + self + } + + /// Sets if unsigned payload option has to be used. + /// See [unsigned payload option](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html) + /// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request. + /// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request, + pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self { + self.unsigned_payload = unsigned_payload.into(); + self + } + + /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny authorized requests + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. + /// + /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self { + // Convert to String to enable deferred parsing of config + self.checksum_algorithm = Some(checksum_algorithm.into()); + self + } + + /// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html), + /// used primarily within AWS EC2. + /// + /// This defaults to the IPv4 endpoint: http://169.254.169.254. One can alternatively use the IPv6 + /// endpoint http://fd00:ec2::254. 
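And a minimal sketch of the unsigned, anonymous access path enabled by `with_skip_signature`; the bucket name is a placeholder, and this only works against buckets that allow anonymous reads:

```rust
use object_store::aws::{AmazonS3, AmazonS3Builder};

fn anonymous_store() -> object_store::Result<AmazonS3> {
    // No credentials are fetched and requests are sent unsigned
    AmazonS3Builder::new()
        .with_region("us-east-1")
        .with_bucket_name("some-public-bucket")
        .with_skip_signature(true)
        .build()
}
```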
+ pub fn with_metadata_endpoint(mut self, endpoint: impl Into) -> Self { + self.metadata_endpoint = Some(endpoint.into()); + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate(mut self, proxy_ca_certificate: impl Into) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// Configure how to provide `copy_if_not_exists` + pub fn with_copy_if_not_exists(mut self, config: S3CopyIfNotExists) -> Self { + self.copy_if_not_exists = Some(config.into()); + self + } + + /// Configure how to provide conditional put operations. + /// if not set, the default value will be `S3ConditionalPut::ETagMatch` + pub fn with_conditional_put(mut self, config: S3ConditionalPut) -> Self { + self.conditional_put = config.into(); + self + } + + /// If set to `true` will ignore any tags provided to put_opts + pub fn with_disable_tagging(mut self, ignore: bool) -> Self { + self.disable_tagging = ignore.into(); + self + } + + /// Use SSE-KMS for server side encryption. + pub fn with_sse_kms_encryption(mut self, kms_key_id: impl Into) -> Self { + self.encryption_type = Some(ConfigValue::Parsed(S3EncryptionType::SseKms)); + if let Some(kms_key_id) = kms_key_id.into().into() { + self.encryption_kms_key_id = Some(kms_key_id); + } + self + } + + /// Use dual server side encryption for server side encryption. + pub fn with_dsse_kms_encryption(mut self, kms_key_id: impl Into) -> Self { + self.encryption_type = Some(ConfigValue::Parsed(S3EncryptionType::DsseKms)); + if let Some(kms_key_id) = kms_key_id.into().into() { + self.encryption_kms_key_id = Some(kms_key_id); + } + self + } + + /// Use SSE-C for server side encryption. + /// Must pass the *base64-encoded* 256-bit customer encryption key. + pub fn with_ssec_encryption(mut self, customer_key_base64: impl Into) -> Self { + self.encryption_type = Some(ConfigValue::Parsed(S3EncryptionType::SseC)); + self.encryption_customer_key_base64 = customer_key_base64.into().into(); + self + } + + /// Set whether to enable bucket key for server side encryption. This overrides + /// the bucket default setting for bucket keys. + /// + /// When bucket keys are disabled, each object is encrypted with a unique data key. + /// When bucket keys are enabled, a single data key is used for the entire bucket, + /// reducing overhead of encryption. + pub fn with_bucket_key(mut self, enabled: bool) -> Self { + self.encryption_bucket_key_enabled = Some(ConfigValue::Parsed(enabled)); + self + } + + /// Set whether to charge requester for bucket operations. 
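A short sketch of selecting the conditional put behaviour explicitly (it defaults to `S3ConditionalPut::ETagMatch` as noted above); bucket and region are placeholders:

```rust
use object_store::aws::{AmazonS3, AmazonS3Builder, S3ConditionalPut};

fn store_with_conditional_put() -> object_store::Result<AmazonS3> {
    AmazonS3Builder::new()
        .with_region("us-east-1")
        .with_bucket_name("my-bucket")
        .with_conditional_put(S3ConditionalPut::ETagMatch)
        .build()
}
```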
+ /// + /// + pub fn with_request_payer(mut self, enabled: bool) -> Self { + self.request_payer = ConfigValue::Parsed(enabled); + self + } + + /// The [`HttpConnector`] to use + /// + /// On non-WASM32 platforms uses [`reqwest`] by default, on WASM32 platforms must be provided + pub fn with_http_connector(mut self, connector: C) -> Self { + self.http_connector = Some(Arc::new(connector)); + self + } + + /// Create a [`AmazonS3`] instance from the provided values, + /// consuming `self`. + pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let http = http_connector(self.http_connector)?; + + let bucket = self.bucket_name.ok_or(Error::MissingBucketName)?; + let region = self.region.unwrap_or_else(|| "us-east-1".to_string()); + let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; + let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; + + let credentials = if let Some(credentials) = self.credentials { + credentials + } else if self.access_key_id.is_some() || self.secret_access_key.is_some() { + match (self.access_key_id, self.secret_access_key, self.token) { + (Some(key_id), Some(secret_key), token) => { + info!("Using Static credential provider"); + let credential = AwsCredential { + key_id, + secret_key, + token, + }; + Arc::new(StaticCredentialProvider::new(credential)) as _ + } + (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), + (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), + (None, None, _) => unreachable!(), + } + } else if let (Ok(token_path), Ok(role_arn)) = ( + std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"), + std::env::var("AWS_ROLE_ARN"), + ) { + // TODO: Replace with `AmazonS3Builder::credentials_from_env` + info!("Using WebIdentity credential provider"); + + let session_name = std::env::var("AWS_ROLE_SESSION_NAME") + .unwrap_or_else(|_| "WebIdentitySession".to_string()); + + let endpoint = format!("https://sts.{region}.amazonaws.com"); + + // Disallow non-HTTPs requests + let options = self.client_options.clone().with_allow_http(false); + + let token = WebIdentityProvider { + token_path, + session_name, + role_arn, + endpoint, + }; + + Arc::new(TokenCredentialProvider::new( + token, + http.connect(&options)?, + self.retry_config.clone(), + )) as _ + } else if let Some(uri) = self.container_credentials_relative_uri { + info!("Using Task credential provider"); + + let options = self.client_options.clone().with_allow_http(true); + + Arc::new(TaskCredentialProvider { + url: format!("http://169.254.170.2{uri}"), + retry: self.retry_config.clone(), + // The instance metadata endpoint is access over HTTP + client: http.connect(&options)?, + cache: Default::default(), + }) as _ + } else { + info!("Using Instance credential provider"); + + let token = InstanceCredentialProvider { + imdsv1_fallback: self.imdsv1_fallback.get()?, + metadata_endpoint: self + .metadata_endpoint + .unwrap_or_else(|| DEFAULT_METADATA_ENDPOINT.into()), + }; + + Arc::new(TokenCredentialProvider::new( + token, + http.connect(&self.client_options.metadata_options())?, + self.retry_config.clone(), + )) as _ + }; + + let (session_provider, zonal_endpoint) = match self.s3_express.get()? 
{ + true => { + let zone = parse_bucket_az(&bucket).ok_or_else(|| { + let bucket = bucket.clone(); + Error::ZoneSuffix { bucket } + })?; + + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-express-Regions-and-Zones.html + let endpoint = format!("https://{bucket}.s3express-{zone}.{region}.amazonaws.com"); + + let session = Arc::new( + TokenCredentialProvider::new( + SessionProvider { + endpoint: endpoint.clone(), + region: region.clone(), + credentials: Arc::clone(&credentials), + }, + http.connect(&self.client_options)?, + self.retry_config.clone(), + ) + .with_min_ttl(Duration::from_secs(60)), // Credentials only valid for 5 minutes + ); + (Some(session as _), Some(endpoint)) + } + false => (None, None), + }; + + // If `endpoint` is provided it's assumed to be consistent with `virtual_hosted_style_request` or `s3_express`. + // For example, if `virtual_hosted_style_request` is true then `endpoint` should have bucket name included. + let virtual_hosted = self.virtual_hosted_style_request.get()?; + let bucket_endpoint = match (&self.endpoint, zonal_endpoint, virtual_hosted) { + (Some(endpoint), _, true) => endpoint.clone(), + (Some(endpoint), _, false) => format!("{}/{}", endpoint.trim_end_matches("/"), bucket), + (None, Some(endpoint), _) => endpoint, + (None, None, true) => format!("https://{bucket}.s3.{region}.amazonaws.com"), + (None, None, false) => format!("https://s3.{region}.amazonaws.com/{bucket}"), + }; + + let encryption_headers = if let Some(encryption_type) = self.encryption_type { + S3EncryptionHeaders::try_new( + &encryption_type.get()?, + self.encryption_kms_key_id, + self.encryption_bucket_key_enabled + .map(|val| val.get()) + .transpose()?, + self.encryption_customer_key_base64, + )? + } else { + S3EncryptionHeaders::default() + }; + + let config = S3Config { + region, + endpoint: self.endpoint, + bucket, + bucket_endpoint, + credentials, + session_provider, + retry_config: self.retry_config, + client_options: self.client_options, + sign_payload: !self.unsigned_payload.get()?, + skip_signature: self.skip_signature.get()?, + disable_tagging: self.disable_tagging.get()?, + checksum, + copy_if_not_exists, + conditional_put: self.conditional_put.get()?, + encryption_headers, + request_payer: self.request_payer.get()?, + }; + + let http_client = http.connect(&config.client_options)?; + let client = Arc::new(S3Client::new(config, http_client)); + + Ok(AmazonS3 { client }) + } +} + +/// Extracts the AZ from a S3 Express One Zone bucket name +/// +/// +fn parse_bucket_az(bucket: &str) -> Option<&str> { + Some(bucket.strip_suffix("--x-s3")?.rsplit_once("--")?.1) +} + +/// Encryption configuration options for S3. +/// +/// These options are used to configure server-side encryption for S3 objects. +/// To configure them, pass them to [`AmazonS3Builder::with_config`]. +/// +/// [SSE-S3]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingServerSideEncryption.html +/// [SSE-KMS]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html +/// [DSSE-KMS]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingDSSEncryption.html +/// [SSE-C]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] +#[non_exhaustive] +pub enum S3EncryptionConfigKey { + /// Type of encryption to use. If set, must be one of "AES256" (SSE-S3), "aws:kms" (SSE-KMS), "aws:kms:dsse" (DSSE-KMS) or "sse-c". 
+ ServerSideEncryption, + /// The KMS key ID to use for server-side encryption. If set, ServerSideEncryption + /// must be "aws:kms" or "aws:kms:dsse". + KmsKeyId, + /// If set to true, will use the bucket's default KMS key for server-side encryption. + /// If set to false, will disable the use of the bucket's default KMS key for server-side encryption. + BucketKeyEnabled, + + /// The base64 encoded, 256-bit customer encryption key to use for server-side encryption. + /// If set, ServerSideEncryption must be "sse-c". + CustomerEncryptionKey, +} + +impl AsRef for S3EncryptionConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::ServerSideEncryption => "aws_server_side_encryption", + Self::KmsKeyId => "aws_sse_kms_key_id", + Self::BucketKeyEnabled => "aws_sse_bucket_key_enabled", + Self::CustomerEncryptionKey => "aws_sse_customer_key_base64", + } + } +} + +#[derive(Debug, Clone)] +enum S3EncryptionType { + S3, + SseKms, + DsseKms, + SseC, +} + +impl crate::config::Parse for S3EncryptionType { + fn parse(s: &str) -> Result { + match s { + "AES256" => Ok(Self::S3), + "aws:kms" => Ok(Self::SseKms), + "aws:kms:dsse" => Ok(Self::DsseKms), + "sse-c" => Ok(Self::SseC), + _ => Err(Error::InvalidEncryptionType { passed: s.into() }.into()), + } + } +} + +impl From<&S3EncryptionType> for &'static str { + fn from(value: &S3EncryptionType) -> Self { + match value { + S3EncryptionType::S3 => "AES256", + S3EncryptionType::SseKms => "aws:kms", + S3EncryptionType::DsseKms => "aws:kms:dsse", + S3EncryptionType::SseC => "sse-c", + } + } +} + +impl std::fmt::Display for S3EncryptionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.into()) + } +} + +/// A sequence of headers to be sent for write requests that specify server-side +/// encryption. +/// +/// Whether these headers are sent depends on both the kind of encryption set +/// and the kind of request being made. 
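A self-contained sketch of the SSE-C flow implemented by the header construction below: enabling SSE-C on the builder and deriving the base64-encoded MD5 that accompanies the customer key. The 32-byte key and bucket name are made up; the `md5` and `base64` APIs mirror the imports used in this file.

```rust
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use md5::{Digest, Md5};
use object_store::aws::AmazonS3Builder;

fn ssec_sketch() -> object_store::Result<()> {
    // A made-up 256-bit customer key, base64-encoded as the builder expects
    let key = [0x42u8; 32];
    let key_b64 = BASE64_STANDARD.encode(key);

    let _s3 = AmazonS3Builder::new()
        .with_region("us-east-1")
        .with_bucket_name("my-bucket")
        .with_ssec_encryption(key_b64.clone())
        .build()?;

    // The MD5 header value is computed over the decoded key bytes and then
    // base64-encoded, mirroring S3EncryptionHeaders::try_new below
    let decoded = BASE64_STANDARD.decode(key_b64.as_bytes()).unwrap();
    let mut hasher = Md5::new();
    hasher.update(decoded);
    let key_md5 = BASE64_STANDARD.encode(hasher.finalize());
    println!("x-amz-server-side-encryption-customer-key-MD5: {key_md5}");
    Ok(())
}
```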
+#[derive(Default, Clone, Debug)] +pub(super) struct S3EncryptionHeaders(pub HeaderMap); + +impl S3EncryptionHeaders { + fn try_new( + encryption_type: &S3EncryptionType, + encryption_kms_key_id: Option, + bucket_key_enabled: Option, + encryption_customer_key_base64: Option, + ) -> Result { + let mut headers = HeaderMap::new(); + match encryption_type { + S3EncryptionType::S3 | S3EncryptionType::SseKms | S3EncryptionType::DsseKms => { + headers.insert( + "x-amz-server-side-encryption", + HeaderValue::from_static(encryption_type.into()), + ); + if let Some(key_id) = encryption_kms_key_id { + headers.insert( + "x-amz-server-side-encryption-aws-kms-key-id", + key_id + .try_into() + .map_err(|err| Error::InvalidEncryptionHeader { + header: "kms-key-id", + source: Box::new(err), + })?, + ); + } + if let Some(bucket_key_enabled) = bucket_key_enabled { + headers.insert( + "x-amz-server-side-encryption-bucket-key-enabled", + HeaderValue::from_static(if bucket_key_enabled { "true" } else { "false" }), + ); + } + } + S3EncryptionType::SseC => { + headers.insert( + "x-amz-server-side-encryption-customer-algorithm", + HeaderValue::from_static("AES256"), + ); + if let Some(key) = encryption_customer_key_base64 { + let mut header_value: HeaderValue = + key.clone() + .try_into() + .map_err(|err| Error::InvalidEncryptionHeader { + header: "x-amz-server-side-encryption-customer-key", + source: Box::new(err), + })?; + header_value.set_sensitive(true); + headers.insert("x-amz-server-side-encryption-customer-key", header_value); + + let decoded_key = BASE64_STANDARD.decode(key.as_bytes()).map_err(|err| { + Error::InvalidEncryptionHeader { + header: "x-amz-server-side-encryption-customer-key", + source: Box::new(err), + } + })?; + let mut hasher = Md5::new(); + hasher.update(decoded_key); + let md5 = BASE64_STANDARD.encode(hasher.finalize()); + let mut md5_header_value: HeaderValue = + md5.try_into() + .map_err(|err| Error::InvalidEncryptionHeader { + header: "x-amz-server-side-encryption-customer-key-MD5", + source: Box::new(err), + })?; + md5_header_value.set_sensitive(true); + headers.insert( + "x-amz-server-side-encryption-customer-key-MD5", + md5_header_value, + ); + } else { + return Err(Error::InvalidEncryptionHeader { + header: "x-amz-server-side-encryption-customer-key", + source: Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Missing customer key", + )), + } + .into()); + } + } + } + Ok(Self(headers)) + } +} + +impl From for HeaderMap { + fn from(headers: S3EncryptionHeaders) -> Self { + headers.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn s3_test_config_from_map() { + let aws_access_key_id = "object_store:fake_access_key_id".to_string(); + let aws_secret_access_key = "object_store:fake_secret_key".to_string(); + let aws_default_region = "object_store:fake_default_region".to_string(); + let aws_endpoint = "object_store:fake_endpoint".to_string(); + let aws_session_token = "object_store:fake_session_token".to_string(); + let options = HashMap::from([ + ("aws_access_key_id", aws_access_key_id.clone()), + ("aws_secret_access_key", aws_secret_access_key), + ("aws_default_region", aws_default_region.clone()), + ("aws_endpoint", aws_endpoint.clone()), + ("aws_session_token", aws_session_token.clone()), + ("aws_unsigned_payload", "true".to_string()), + ("aws_checksum_algorithm", "sha256".to_string()), + ]); + + let builder = options + .into_iter() + .fold(AmazonS3Builder::new(), |builder, (key, value)| { + 
builder.with_config(key.parse().unwrap(), value) + }) + .with_config(AmazonS3ConfigKey::SecretAccessKey, "new-secret-key"); + + assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); + assert_eq!(builder.secret_access_key.unwrap(), "new-secret-key"); + assert_eq!(builder.region.unwrap(), aws_default_region); + assert_eq!(builder.endpoint.unwrap(), aws_endpoint); + assert_eq!(builder.token.unwrap(), aws_session_token); + assert_eq!( + builder.checksum_algorithm.unwrap().get().unwrap(), + Checksum::SHA256 + ); + assert!(builder.unsigned_payload.get().unwrap()); + } + + #[test] + fn s3_test_config_get_value() { + let aws_access_key_id = "object_store:fake_access_key_id".to_string(); + let aws_secret_access_key = "object_store:fake_secret_key".to_string(); + let aws_default_region = "object_store:fake_default_region".to_string(); + let aws_endpoint = "object_store:fake_endpoint".to_string(); + let aws_session_token = "object_store:fake_session_token".to_string(); + + let builder = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::AccessKeyId, &aws_access_key_id) + .with_config(AmazonS3ConfigKey::SecretAccessKey, &aws_secret_access_key) + .with_config(AmazonS3ConfigKey::DefaultRegion, &aws_default_region) + .with_config(AmazonS3ConfigKey::Endpoint, &aws_endpoint) + .with_config(AmazonS3ConfigKey::Token, &aws_session_token) + .with_config(AmazonS3ConfigKey::UnsignedPayload, "true") + .with_config("aws_server_side_encryption".parse().unwrap(), "AES256") + .with_config("aws_sse_kms_key_id".parse().unwrap(), "some_key_id") + .with_config("aws_sse_bucket_key_enabled".parse().unwrap(), "true") + .with_config( + "aws_sse_customer_key_base64".parse().unwrap(), + "some_customer_key", + ); + + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::AccessKeyId) + .unwrap(), + aws_access_key_id + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::SecretAccessKey) + .unwrap(), + aws_secret_access_key + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::DefaultRegion) + .unwrap(), + aws_default_region + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::Endpoint) + .unwrap(), + aws_endpoint + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Token).unwrap(), + aws_session_token + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::UnsignedPayload) + .unwrap(), + "true" + ); + assert_eq!( + builder + .get_config_value(&"aws_server_side_encryption".parse().unwrap()) + .unwrap(), + "AES256" + ); + assert_eq!( + builder + .get_config_value(&"aws_sse_kms_key_id".parse().unwrap()) + .unwrap(), + "some_key_id" + ); + assert_eq!( + builder + .get_config_value(&"aws_sse_bucket_key_enabled".parse().unwrap()) + .unwrap(), + "true" + ); + assert_eq!( + builder + .get_config_value(&"aws_sse_customer_key_base64".parse().unwrap()) + .unwrap(), + "some_customer_key" + ); + } + + #[test] + fn s3_default_region() { + let builder = AmazonS3Builder::new() + .with_bucket_name("foo") + .build() + .unwrap(); + assert_eq!(builder.client.config.region, "us-east-1"); + } + + #[test] + fn s3_test_bucket_endpoint() { + let builder = AmazonS3Builder::new() + .with_endpoint("http://some.host:1234") + .with_bucket_name("foo") + .build() + .unwrap(); + assert_eq!( + builder.client.config.bucket_endpoint, + "http://some.host:1234/foo" + ); + + let builder = AmazonS3Builder::new() + .with_endpoint("http://some.host:1234/") + .with_bucket_name("foo") + .build() + .unwrap(); + assert_eq!( + builder.client.config.bucket_endpoint, + 
"http://some.host:1234/foo" + ); + } + + #[test] + fn s3_test_urls() { + let mut builder = AmazonS3Builder::new(); + builder.parse_url("s3://bucket/path").unwrap(); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("s3://buckets.can.have.dots/path") + .unwrap(); + assert_eq!( + builder.bucket_name, + Some("buckets.can.have.dots".to_string()) + ); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com/bucket") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com/bucket.with.dot/path") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + assert_eq!(builder.bucket_name, Some("bucket.with.dot".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://bucket.s3.region.amazonaws.com") + .unwrap(); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + assert_eq!(builder.region, Some("region".to_string())); + assert!(builder.virtual_hosted_style_request.get().unwrap()); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://account123.r2.cloudflarestorage.com/bucket-123") + .unwrap(); + + assert_eq!(builder.bucket_name, Some("bucket-123".to_string())); + assert_eq!(builder.region, Some("auto".to_string())); + assert_eq!( + builder.endpoint, + Some("https://account123.r2.cloudflarestorage.com".to_string()) + ); + + let err_cases = [ + "mailto://bucket/path", + "https://s3.bucket.mydomain.com", + "https://s3.bucket.foo.amazonaws.com", + "https://bucket.mydomain.region.amazonaws.com", + "https://bucket.s3.region.bar.amazonaws.com", + "https://bucket.foo.s3.amazonaws.com", + ]; + let mut builder = AmazonS3Builder::new(); + for case in err_cases { + builder.parse_url(case).unwrap_err(); + } + } + + #[tokio::test] + async fn s3_test_proxy_url() { + let s3 = AmazonS3Builder::new() + .with_access_key_id("access_key_id") + .with_secret_access_key("secret_access_key") + .with_region("region") + .with_bucket_name("bucket_name") + .with_allow_http(true) + .with_proxy_url("https://example.com") + .build(); + + assert!(s3.is_ok()); + + let err = AmazonS3Builder::new() + .with_access_key_id("access_key_id") + .with_secret_access_key("secret_access_key") + .with_region("region") + .with_bucket_name("bucket_name") + .with_allow_http(true) + .with_proxy_url("asdf://example.com") + .build() + .unwrap_err() + .to_string(); + + assert_eq!("Generic HTTP client error: builder error", err); + } + + #[test] + fn test_invalid_config() { + let err = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::ImdsV1Fallback, "enabled") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: failed to parse \"enabled\" as boolean" + ); + + let err = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::Checksum, "md5") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: \"md5\" is not a valid checksum algorithm" + ); + } + + #[test] + fn test_parse_bucket_az() { + let cases = [ + 
("bucket-base-name--usw2-az1--x-s3", Some("usw2-az1")), + ("bucket-base--name--azid--x-s3", Some("azid")), + ("bucket-base-name", None), + ("bucket-base-name--x-s3", None), + ]; + + for (bucket, expected) in cases { + assert_eq!(parse_bucket_az(bucket), expected) + } + } + + #[test] + fn aws_test_client_opts() { + let key = "AWS_PROXY_URL"; + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + assert_eq!( + AmazonS3ConfigKey::Client(ClientConfigKey::ProxyUrl), + config_key + ); + } else { + panic!("{} not propagated as ClientConfigKey", key); + } + } +} diff --git a/src/aws/checksum.rs b/src/aws/checksum.rs new file mode 100644 index 0000000..d15bbf0 --- /dev/null +++ b/src/aws/checksum.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::config::Parse; +use std::str::FromStr; + +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Enum representing checksum algorithm supported by S3. +pub enum Checksum { + /// SHA-256 algorithm. + SHA256, +} + +impl std::fmt::Display for Checksum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self { + Self::SHA256 => write!(f, "sha256"), + } + } +} + +impl FromStr for Checksum { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "sha256" => Ok(Self::SHA256), + _ => Err(()), + } + } +} + +impl TryFrom<&String> for Checksum { + type Error = (); + + fn try_from(value: &String) -> Result { + value.parse() + } +} + +impl Parse for Checksum { + fn parse(v: &str) -> crate::Result { + v.parse().map_err(|_| crate::Error::Generic { + store: "Config", + source: format!("\"{v}\" is not a valid checksum algorithm").into(), + }) + } +} diff --git a/src/aws/client.rs b/src/aws/client.rs new file mode 100644 index 0000000..fb2a033 --- /dev/null +++ b/src/aws/client.rs @@ -0,0 +1,931 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
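A quick illustration of how the `Checksum` value defined in `checksum.rs` above round-trips between its string form (the value accepted by `aws_checksum_algorithm`) and the enum, assuming it is re-exported as `object_store::aws::Checksum`:

```rust
use object_store::aws::Checksum;

fn main() {
    // Parsing is case-insensitive; only SHA-256 is currently supported
    let checksum: Checksum = "SHA256".parse().unwrap();
    assert_eq!(checksum, Checksum::SHA256);
    assert_eq!(checksum.to_string(), "sha256");
    assert!("md5".parse::<Checksum>().is_err());
}
```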
+ +use crate::aws::builder::S3EncryptionHeaders; +use crate::aws::checksum::Checksum; +use crate::aws::credential::{AwsCredential, CredentialExt}; +use crate::aws::{ + AwsAuthorizer, AwsCredentialProvider, S3ConditionalPut, S3CopyIfNotExists, COPY_SOURCE_HEADER, + STORE, STRICT_PATH_ENCODE_SET, TAGS_HEADER, +}; +use crate::client::builder::{HttpRequestBuilder, RequestBuilderError}; +use crate::client::get::GetClient; +use crate::client::header::{get_etag, HeaderConfig}; +use crate::client::header::{get_put_result, get_version}; +use crate::client::list::ListClient; +use crate::client::retry::RetryExt; +use crate::client::s3::{ + CompleteMultipartUpload, CompleteMultipartUploadResult, CopyPartResult, + InitiateMultipartUploadResult, ListResponse, PartMetadata, +}; +use crate::client::{GetOptionsExt, HttpClient, HttpError, HttpResponse}; +use crate::multipart::PartId; +use crate::path::DELIMITER; +use crate::{ + Attribute, Attributes, ClientOptions, GetOptions, ListResult, MultipartId, Path, + PutMultipartOpts, PutPayload, PutResult, Result, RetryConfig, TagSet, +}; +use async_trait::async_trait; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::{Buf, Bytes}; +use http::header::{ + CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, + CONTENT_TYPE, +}; +use http::{HeaderMap, HeaderName, Method}; +use itertools::Itertools; +use md5::{Digest, Md5}; +use percent_encoding::{utf8_percent_encode, PercentEncode}; +use quick_xml::events::{self as xml_events}; +use ring::digest; +use ring::digest::Context; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +const VERSION_HEADER: &str = "x-amz-version-id"; +const SHA256_CHECKSUM: &str = "x-amz-checksum-sha256"; +const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-amz-meta-"; +const ALGORITHM: &str = "x-amz-checksum-algorithm"; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, thiserror::Error)] +pub(crate) enum Error { + #[error("Error performing DeleteObjects request: {}", source)] + DeleteObjectsRequest { + source: crate::client::retry::RetryError, + }, + + #[error( + "DeleteObjects request failed for key {}: {} (code: {})", + path, + message, + code + )] + DeleteFailed { + path: String, + code: String, + message: String, + }, + + #[error("Error getting DeleteObjects response body: {}", source)] + DeleteObjectsResponse { source: HttpError }, + + #[error("Got invalid DeleteObjects response: {}", source)] + InvalidDeleteObjectsResponse { + source: Box, + }, + + #[error("Error performing list request: {}", source)] + ListRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting list response body: {}", source)] + ListResponseBody { source: HttpError }, + + #[error("Error getting create multipart response body: {}", source)] + CreateMultipartResponseBody { source: HttpError }, + + #[error("Error performing complete multipart request: {}: {}", path, source)] + CompleteMultipartRequest { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error getting complete multipart response body: {}", source)] + CompleteMultipartResponseBody { source: HttpError }, + + #[error("Got invalid list response: {}", source)] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[error("Got invalid multipart response: {}", source)] + InvalidMultipartResponse { source: quick_xml::de::DeError }, + + #[error("Unable to extract metadata from headers: {}", source)] + Metadata { + source: crate::client::header::Error, + }, 
+} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::CompleteMultipartRequest { source, path } => source.error(STORE, path), + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +pub(crate) enum PutPartPayload<'a> { + Part(PutPayload), + Copy(&'a Path), +} + +impl Default for PutPartPayload<'_> { + fn default() -> Self { + Self::Part(PutPayload::default()) + } +} + +pub(crate) enum CompleteMultipartMode { + Overwrite, + Create, +} + +#[derive(Deserialize)] +#[serde(rename_all = "PascalCase", rename = "DeleteResult")] +struct BatchDeleteResponse { + #[serde(rename = "$value")] + content: Vec, +} + +#[derive(Deserialize)] +enum DeleteObjectResult { + #[allow(unused)] + Deleted(DeletedObject), + Error(DeleteError), +} + +#[derive(Deserialize)] +#[serde(rename_all = "PascalCase", rename = "Deleted")] +struct DeletedObject { + #[allow(dead_code)] + key: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "PascalCase", rename = "Error")] +struct DeleteError { + key: String, + code: String, + message: String, +} + +impl From for Error { + fn from(err: DeleteError) -> Self { + Self::DeleteFailed { + path: err.key, + code: err.code, + message: err.message, + } + } +} + +#[derive(Debug)] +pub(crate) struct S3Config { + pub region: String, + pub endpoint: Option, + pub bucket: String, + pub bucket_endpoint: String, + pub credentials: AwsCredentialProvider, + pub session_provider: Option, + pub retry_config: RetryConfig, + pub client_options: ClientOptions, + pub sign_payload: bool, + pub skip_signature: bool, + pub disable_tagging: bool, + pub checksum: Option, + pub copy_if_not_exists: Option, + pub conditional_put: S3ConditionalPut, + pub request_payer: bool, + pub(super) encryption_headers: S3EncryptionHeaders, +} + +impl S3Config { + pub(crate) fn path_url(&self, path: &Path) -> String { + format!("{}/{}", self.bucket_endpoint, encode_path(path)) + } + + async fn get_session_credential(&self) -> Result> { + let credential = match self.skip_signature { + false => { + let provider = self.session_provider.as_ref().unwrap_or(&self.credentials); + Some(provider.get_credential().await?) 
+ } + true => None, + }; + + Ok(SessionCredential { + credential, + session_token: self.session_provider.is_some(), + config: self, + }) + } + + pub(crate) async fn get_credential(&self) -> Result>> { + Ok(match self.skip_signature { + false => Some(self.credentials.get_credential().await?), + true => None, + }) + } + + #[inline] + pub(crate) fn is_s3_express(&self) -> bool { + self.session_provider.is_some() + } +} + +struct SessionCredential<'a> { + credential: Option>, + session_token: bool, + config: &'a S3Config, +} + +impl SessionCredential<'_> { + fn authorizer(&self) -> Option> { + let mut authorizer = + AwsAuthorizer::new(self.credential.as_deref()?, "s3", &self.config.region) + .with_sign_payload(self.config.sign_payload) + .with_request_payer(self.config.request_payer); + + if self.session_token { + let token = HeaderName::from_static("x-amz-s3session-token"); + authorizer = authorizer.with_token_header(token) + } + + Some(authorizer) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RequestError { + #[error(transparent)] + Generic { + #[from] + source: crate::Error, + }, + + #[error("Retry")] + Retry { + source: crate::client::retry::RetryError, + path: String, + }, +} + +impl From for crate::Error { + fn from(value: RequestError) -> Self { + match value { + RequestError::Generic { source } => source, + RequestError::Retry { source, path } => source.error(STORE, path), + } + } +} + +/// A builder for a request allowing customisation of the headers and query string +pub(crate) struct Request<'a> { + path: &'a Path, + config: &'a S3Config, + builder: HttpRequestBuilder, + payload_sha256: Option, + payload: Option, + use_session_creds: bool, + idempotent: bool, + retry_on_conflict: bool, + retry_error_body: bool, +} + +impl Request<'_> { + pub(crate) fn query(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + pub(crate) fn header(self, k: K, v: &str) -> Self + where + K: TryInto, + K::Error: Into, + { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + pub(crate) fn headers(self, headers: HeaderMap) -> Self { + let builder = self.builder.headers(headers); + Self { builder, ..self } + } + + pub(crate) fn idempotent(self, idempotent: bool) -> Self { + Self { idempotent, ..self } + } + + pub(crate) fn retry_on_conflict(self, retry_on_conflict: bool) -> Self { + Self { + retry_on_conflict, + ..self + } + } + + pub(crate) fn retry_error_body(self, retry_error_body: bool) -> Self { + Self { + retry_error_body, + ..self + } + } + + pub(crate) fn with_encryption_headers(self) -> Self { + let headers = self.config.encryption_headers.clone().into(); + let builder = self.builder.headers(headers); + Self { builder, ..self } + } + + pub(crate) fn with_session_creds(self, use_session_creds: bool) -> Self { + Self { + use_session_creds, + ..self + } + } + + pub(crate) fn with_tags(mut self, tags: TagSet) -> Self { + let tags = tags.encoded(); + if !tags.is_empty() && !self.config.disable_tagging { + self.builder = self.builder.header(&TAGS_HEADER, tags); + } + self + } + + pub(crate) fn with_attributes(self, attributes: Attributes) -> Self { + let mut has_content_type = false; + let mut builder = self.builder; + for (k, v) in &attributes { + builder = match k { + Attribute::CacheControl => builder.header(CACHE_CONTROL, v.as_ref()), + Attribute::ContentDisposition => builder.header(CONTENT_DISPOSITION, v.as_ref()), + Attribute::ContentEncoding => builder.header(CONTENT_ENCODING, v.as_ref()), + 
Attribute::ContentLanguage => builder.header(CONTENT_LANGUAGE, v.as_ref()), + Attribute::ContentType => { + has_content_type = true; + builder.header(CONTENT_TYPE, v.as_ref()) + } + Attribute::Metadata(k_suffix) => builder.header( + &format!("{}{}", USER_DEFINED_METADATA_HEADER_PREFIX, k_suffix), + v.as_ref(), + ), + }; + } + + if !has_content_type { + if let Some(value) = self.config.client_options.get_content_type(self.path) { + builder = builder.header(CONTENT_TYPE, value); + } + } + Self { builder, ..self } + } + + pub(crate) fn with_extensions(self, extensions: ::http::Extensions) -> Self { + let builder = self.builder.extensions(extensions); + Self { builder, ..self } + } + + pub(crate) fn with_payload(mut self, payload: PutPayload) -> Self { + if (!self.config.skip_signature && self.config.sign_payload) + || self.config.checksum.is_some() + { + let mut sha256 = Context::new(&digest::SHA256); + payload.iter().for_each(|x| sha256.update(x)); + let payload_sha256 = sha256.finish(); + + if let Some(Checksum::SHA256) = self.config.checksum { + self.builder = self + .builder + .header(SHA256_CHECKSUM, BASE64_STANDARD.encode(payload_sha256)); + } + self.payload_sha256 = Some(payload_sha256); + } + + let content_length = payload.content_length(); + self.builder = self.builder.header(CONTENT_LENGTH, content_length); + self.payload = Some(payload); + self + } + + pub(crate) async fn send(self) -> Result { + let credential = match self.use_session_creds { + true => self.config.get_session_credential().await?, + false => SessionCredential { + credential: self.config.get_credential().await?, + session_token: false, + config: self.config, + }, + }; + + let sha = self.payload_sha256.as_ref().map(|x| x.as_ref()); + + let path = self.path.as_ref(); + self.builder + .with_aws_sigv4(credential.authorizer(), sha) + .retryable(&self.config.retry_config) + .retry_on_conflict(self.retry_on_conflict) + .idempotent(self.idempotent) + .retry_error_body(self.retry_error_body) + .payload(self.payload) + .send() + .await + .map_err(|source| { + let path = path.into(); + RequestError::Retry { source, path } + }) + } + + pub(crate) async fn do_put(self) -> Result { + let response = self.send().await?; + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) + } +} + +#[derive(Debug)] +pub(crate) struct S3Client { + pub config: S3Config, + pub client: HttpClient, +} + +impl S3Client { + pub(crate) fn new(config: S3Config, client: HttpClient) -> Self { + Self { config, client } + } + + pub(crate) fn request<'a>(&'a self, method: Method, path: &'a Path) -> Request<'a> { + let url = self.config.path_url(path); + Request { + path, + builder: self.client.request(method, url), + payload: None, + payload_sha256: None, + config: &self.config, + use_session_creds: true, + idempotent: false, + retry_on_conflict: false, + retry_error_body: false, + } + } + + /// Make an S3 Delete Objects request + /// + /// Produces a vector of results, one for each path in the input vector. If + /// the delete was successful, the path is returned in the `Ok` variant. If + /// there was an error for a certain path, the error will be returned in the + /// vector. If there was an issue with making the overall request, an error + /// will be returned at the top level. 
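A caller-side sketch (crate-internal, hypothetical helper) of the result contract described above: one entry per input path, in input order, with per-path failures surfaced individually.

```rust
// Hypothetical helper illustrating how the per-path results might be consumed
async fn delete_many(client: &S3Client, paths: Vec<Path>) -> Result<()> {
    let results = client.bulk_delete_request(paths.clone()).await?;
    debug_assert_eq!(results.len(), paths.len());
    for result in results {
        match result {
            Ok(path) => println!("deleted {path}"),
            Err(e) => eprintln!("delete failed: {e}"),
        }
    }
    Ok(())
}
```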
+ pub(crate) async fn bulk_delete_request(&self, paths: Vec) -> Result>> { + if paths.is_empty() { + return Ok(Vec::new()); + } + + let credential = self.config.get_session_credential().await?; + let url = format!("{}?delete", self.config.bucket_endpoint); + + let mut buffer = Vec::new(); + let mut writer = quick_xml::Writer::new(&mut buffer); + writer + .write_event(xml_events::Event::Start( + xml_events::BytesStart::new("Delete") + .with_attributes([("xmlns", "http://s3.amazonaws.com/doc/2006-03-01/")]), + )) + .unwrap(); + for path in &paths { + // {path} + writer + .write_event(xml_events::Event::Start(xml_events::BytesStart::new( + "Object", + ))) + .unwrap(); + writer + .write_event(xml_events::Event::Start(xml_events::BytesStart::new("Key"))) + .unwrap(); + writer + .write_event(xml_events::Event::Text(xml_events::BytesText::new( + path.as_ref(), + ))) + .map_err(|err| crate::Error::Generic { + store: STORE, + source: Box::new(err), + })?; + writer + .write_event(xml_events::Event::End(xml_events::BytesEnd::new("Key"))) + .unwrap(); + writer + .write_event(xml_events::Event::End(xml_events::BytesEnd::new("Object"))) + .unwrap(); + } + writer + .write_event(xml_events::Event::End(xml_events::BytesEnd::new("Delete"))) + .unwrap(); + + let body = Bytes::from(buffer); + + let mut builder = self.client.request(Method::POST, url); + + let digest = digest::digest(&digest::SHA256, &body); + builder = builder.header(SHA256_CHECKSUM, BASE64_STANDARD.encode(digest)); + + // S3 *requires* DeleteObjects to include a Content-MD5 header: + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html + // > "The Content-MD5 request header is required for all Multi-Object Delete requests" + // Some platforms, like MinIO, enforce this requirement and fail requests without the header. + let mut hasher = Md5::new(); + hasher.update(&body); + builder = builder.header("Content-MD5", BASE64_STANDARD.encode(hasher.finalize())); + + let response = builder + .header(CONTENT_TYPE, "application/xml") + .body(body) + .with_aws_sigv4(credential.authorizer(), Some(digest.as_ref())) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| Error::DeleteObjectsRequest { source })? + .into_body() + .bytes() + .await + .map_err(|source| Error::DeleteObjectsResponse { source })?; + + let response: BatchDeleteResponse = + quick_xml::de::from_reader(response.reader()).map_err(|err| { + Error::InvalidDeleteObjectsResponse { + source: Box::new(err), + } + })?; + + // Assume all were ok, then fill in errors. This guarantees output order + // matches input order. 
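+        // e.g. for input `[a, b, c]` where only `b` failed, the output is
+        // `[Ok(a), Err(..), Ok(c)]`.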
+ let mut results: Vec> = paths.iter().cloned().map(Ok).collect(); + for content in response.content.into_iter() { + if let DeleteObjectResult::Error(error) = content { + let path = + Path::parse(&error.key).map_err(|err| Error::InvalidDeleteObjectsResponse { + source: Box::new(err), + })?; + let i = paths.iter().find_position(|&p| p == &path).unwrap().0; + results[i] = Err(Error::from(error).into()); + } + } + + Ok(results) + } + + /// Make an S3 Copy request + pub(crate) fn copy_request<'a>(&'a self, from: &Path, to: &'a Path) -> Request<'a> { + let source = format!("{}/{}", self.config.bucket, encode_path(from)); + + let mut copy_source_encryption_headers = HeaderMap::new(); + if let Some(customer_algorithm) = self + .config + .encryption_headers + .0 + .get("x-amz-server-side-encryption-customer-algorithm") + { + copy_source_encryption_headers.insert( + "x-amz-copy-source-server-side-encryption-customer-algorithm", + customer_algorithm.clone(), + ); + } + if let Some(customer_key) = self + .config + .encryption_headers + .0 + .get("x-amz-server-side-encryption-customer-key") + { + copy_source_encryption_headers.insert( + "x-amz-copy-source-server-side-encryption-customer-key", + customer_key.clone(), + ); + } + if let Some(customer_key_md5) = self + .config + .encryption_headers + .0 + .get("x-amz-server-side-encryption-customer-key-MD5") + { + copy_source_encryption_headers.insert( + "x-amz-copy-source-server-side-encryption-customer-key-MD5", + customer_key_md5.clone(), + ); + } + + self.request(Method::PUT, to) + .idempotent(true) + .retry_error_body(true) + .header(©_SOURCE_HEADER, &source) + .headers(self.config.encryption_headers.clone().into()) + .headers(copy_source_encryption_headers) + .with_session_creds(false) + } + + pub(crate) async fn create_multipart( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result { + let PutMultipartOpts { + tags, + attributes, + extensions, + } = opts; + + let mut request = self.request(Method::POST, location); + if let Some(algorithm) = self.config.checksum { + match algorithm { + Checksum::SHA256 => { + request = request.header(ALGORITHM, "SHA256"); + } + } + } + let response = request + .query(&[("uploads", "")]) + .with_encryption_headers() + .with_attributes(attributes) + .with_tags(tags) + .with_extensions(extensions) + .idempotent(true) + .send() + .await? + .into_body() + .bytes() + .await + .map_err(|source| Error::CreateMultipartResponseBody { source })?; + + let response: InitiateMultipartUploadResult = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; + + Ok(response.upload_id) + } + + pub(crate) async fn put_part( + &self, + path: &Path, + upload_id: &MultipartId, + part_idx: usize, + data: PutPartPayload<'_>, + ) -> Result { + let is_copy = matches!(data, PutPartPayload::Copy(_)); + let part = (part_idx + 1).to_string(); + + let mut request = self + .request(Method::PUT, path) + .query(&[("partNumber", &part), ("uploadId", upload_id)]) + .idempotent(true); + + request = match data { + PutPartPayload::Part(payload) => request.with_payload(payload), + PutPartPayload::Copy(path) => request.header( + "x-amz-copy-source", + &format!("{}/{}", self.config.bucket, encode_path(path)), + ), + }; + + if self + .config + .encryption_headers + .0 + .contains_key("x-amz-server-side-encryption-customer-algorithm") + { + // If SSE-C is used, we must include the encryption headers in every upload request. 
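+            // i.e. the `x-amz-server-side-encryption-customer-algorithm`,
+            // `x-amz-server-side-encryption-customer-key` and
+            // `x-amz-server-side-encryption-customer-key-MD5` headers held in
+            // `self.config.encryption_headers`.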
+ request = request.with_encryption_headers(); + } + let (parts, body) = request.send().await?.into_parts(); + let checksum_sha256 = parts + .headers + .get(SHA256_CHECKSUM) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + + let e_tag = match is_copy { + false => get_etag(&parts.headers).map_err(|source| Error::Metadata { source })?, + true => { + let response = body + .bytes() + .await + .map_err(|source| Error::CreateMultipartResponseBody { source })?; + let response: CopyPartResult = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; + response.e_tag + } + }; + + let content_id = if self.config.checksum == Some(Checksum::SHA256) { + let meta = PartMetadata { + e_tag, + checksum_sha256, + }; + quick_xml::se::to_string(&meta).unwrap() + } else { + e_tag + }; + + Ok(PartId { content_id }) + } + + pub(crate) async fn abort_multipart(&self, location: &Path, upload_id: &str) -> Result<()> { + self.request(Method::DELETE, location) + .query(&[("uploadId", upload_id)]) + .with_encryption_headers() + .send() + .await?; + + Ok(()) + } + + pub(crate) async fn complete_multipart( + &self, + location: &Path, + upload_id: &str, + parts: Vec, + mode: CompleteMultipartMode, + ) -> Result { + let parts = if parts.is_empty() { + // If no parts were uploaded, upload an empty part + // otherwise the completion request will fail + let part = self + .put_part( + location, + &upload_id.to_string(), + 0, + PutPartPayload::default(), + ) + .await?; + vec![part] + } else { + parts + }; + let request = CompleteMultipartUpload::from(parts); + let body = quick_xml::se::to_string(&request).unwrap(); + + let credential = self.config.get_session_credential().await?; + let url = self.config.path_url(location); + + let request = self + .client + .post(url) + .query(&[("uploadId", upload_id)]) + .body(body) + .with_aws_sigv4(credential.authorizer(), None); + + let request = match mode { + CompleteMultipartMode::Overwrite => request, + CompleteMultipartMode::Create => request.header("If-None-Match", "*"), + }; + + let response = request + .retryable(&self.config.retry_config) + .idempotent(true) + .retry_error_body(true) + .send() + .await + .map_err(|source| Error::CompleteMultipartRequest { + source, + path: location.as_ref().to_string(), + })?; + + let version = get_version(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?; + + let data = response + .into_body() + .bytes() + .await + .map_err(|source| Error::CompleteMultipartResponseBody { source })?; + + let response: CompleteMultipartUploadResult = quick_xml::de::from_reader(data.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; + + Ok(PutResult { + e_tag: Some(response.e_tag), + version, + }) + } + + #[cfg(test)] + pub(crate) async fn get_object_tagging(&self, path: &Path) -> Result { + let credential = self.config.get_session_credential().await?; + let url = format!("{}?tagging", self.config.path_url(path)); + let response = self + .client + .request(Method::GET, url) + .with_aws_sigv4(credential.authorizer(), None) + .send_retry(&self.config.retry_config) + .await + .map_err(|e| e.error(STORE, path.to_string()))?; + Ok(response) + } +} + +#[async_trait] +impl GetClient for S3Client { + const STORE: &'static str = STORE; + + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: false, + last_modified_required: false, + version_header: Some(VERSION_HEADER), + user_defined_metadata_prefix: 
Some(USER_DEFINED_METADATA_HEADER_PREFIX), + }; + + /// Make an S3 GET request + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let credential = self.config.get_session_credential().await?; + let url = self.config.path_url(path); + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut builder = self.client.request(method, url); + if self + .config + .encryption_headers + .0 + .contains_key("x-amz-server-side-encryption-customer-algorithm") + { + builder = builder.headers(self.config.encryption_headers.clone().into()); + } + + if let Some(v) = &options.version { + builder = builder.query(&[("versionId", v)]) + } + + let response = builder + .with_get_options(options) + .with_aws_sigv4(credential.authorizer(), None) + .send_retry(&self.config.retry_config) + .await + .map_err(|e| e.error(STORE, path.to_string()))?; + + Ok(response) + } +} + +#[async_trait] +impl ListClient for Arc { + /// Make an S3 List request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + offset: Option<&str>, + ) -> Result<(ListResult, Option)> { + let credential = self.config.get_session_credential().await?; + let url = self.config.bucket_endpoint.clone(); + + let mut query = Vec::with_capacity(4); + + if let Some(token) = token { + query.push(("continuation-token", token)) + } + + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + query.push(("list-type", "2")); + + if let Some(prefix) = prefix { + query.push(("prefix", prefix)) + } + + if let Some(offset) = offset { + query.push(("start-after", offset)) + } + + let response = self + .client + .request(Method::GET, &url) + .query(&query) + .with_aws_sigv4(credential.authorizer(), None) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| Error::ListRequest { source })? + .into_body() + .bytes() + .await + .map_err(|source| Error::ListResponseBody { source })?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; + + let token = response.next_continuation_token.take(); + + Ok((response.try_into()?, token)) + } +} + +fn encode_path(path: &Path) -> PercentEncode<'_> { + utf8_percent_encode(path.as_ref(), &STRICT_PATH_ENCODE_SET) +} diff --git a/src/aws/credential.rs b/src/aws/credential.rs new file mode 100644 index 0000000..1b62842 --- /dev/null +++ b/src/aws/credential.rs @@ -0,0 +1,1159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
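+
+// A brief sketch of the SigV4 signing-key derivation implemented by
+// `AwsCredential::sign` below:
+//
+//   date_key    = HMAC-SHA256("AWS4" + secret_key, "YYYYMMDD")
+//   region_key  = HMAC-SHA256(date_key, region)
+//   service_key = HMAC-SHA256(region_key, service)
+//   signing_key = HMAC-SHA256(service_key, "aws4_request")
+//   signature   = hex(HMAC-SHA256(signing_key, string_to_sign))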
+ +use crate::aws::{AwsCredentialProvider, STORE, STRICT_ENCODE_SET, STRICT_PATH_ENCODE_SET}; +use crate::client::builder::HttpRequestBuilder; +use crate::client::retry::RetryExt; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::client::{HttpClient, HttpError, HttpRequest, TokenProvider}; +use crate::util::{hex_digest, hex_encode, hmac_sha256}; +use crate::{CredentialProvider, Result, RetryConfig}; +use async_trait::async_trait; +use bytes::Buf; +use chrono::{DateTime, Utc}; +use http::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; +use http::{Method, StatusCode}; +use percent_encoding::utf8_percent_encode; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::warn; +use url::Url; + +#[derive(Debug, thiserror::Error)] +#[allow(clippy::enum_variant_names)] +enum Error { + #[error("Error performing CreateSession request: {source}")] + CreateSessionRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting CreateSession response: {source}")] + CreateSessionResponse { source: HttpError }, + + #[error("Invalid CreateSessionOutput response: {source}")] + CreateSessionOutput { source: quick_xml::DeError }, +} + +impl From for crate::Error { + fn from(value: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(value), + } + } +} + +type StdError = Box; + +/// SHA256 hash of empty string +static EMPTY_SHA256_HASH: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; +static UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD"; +static STREAMING_PAYLOAD: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"; + +/// A set of AWS security credentials +#[derive(Debug, Eq, PartialEq)] +pub struct AwsCredential { + /// AWS_ACCESS_KEY_ID + pub key_id: String, + /// AWS_SECRET_ACCESS_KEY + pub secret_key: String, + /// AWS_SESSION_TOKEN + pub token: Option, +} + +impl AwsCredential { + /// Signs a string + /// + /// + fn sign(&self, to_sign: &str, date: DateTime, region: &str, service: &str) -> String { + let date_string = date.format("%Y%m%d").to_string(); + let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string); + let region_hmac = hmac_sha256(date_hmac, region); + let service_hmac = hmac_sha256(region_hmac, service); + let signing_hmac = hmac_sha256(service_hmac, b"aws4_request"); + hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref()) + } +} + +/// Authorize a [`HttpRequest`] with an [`AwsCredential`] using [AWS SigV4] +/// +/// [AWS SigV4]: https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html +#[derive(Debug)] +pub struct AwsAuthorizer<'a> { + date: Option>, + credential: &'a AwsCredential, + service: &'a str, + region: &'a str, + token_header: Option, + sign_payload: bool, + request_payer: bool, +} + +static DATE_HEADER: HeaderName = HeaderName::from_static("x-amz-date"); +static HASH_HEADER: HeaderName = HeaderName::from_static("x-amz-content-sha256"); +static TOKEN_HEADER: HeaderName = HeaderName::from_static("x-amz-security-token"); +static REQUEST_PAYER_HEADER: HeaderName = HeaderName::from_static("x-amz-request-payer"); +static REQUEST_PAYER_HEADER_VALUE: HeaderValue = HeaderValue::from_static("requester"); +const ALGORITHM: &str = "AWS4-HMAC-SHA256"; + +impl<'a> AwsAuthorizer<'a> { + /// Create a new [`AwsAuthorizer`] + pub fn new(credential: &'a AwsCredential, service: &'a str, region: &'a str) -> Self { + Self { + credential, + service, + region, + date: None, + sign_payload: true, + 
token_header: None, + request_payer: false, + } + } + + /// Controls whether this [`AwsAuthorizer`] will attempt to sign the request payload, + /// the default is `true` + pub fn with_sign_payload(mut self, signed: bool) -> Self { + self.sign_payload = signed; + self + } + + /// Overrides the header name for security tokens, defaults to `x-amz-security-token` + pub(crate) fn with_token_header(mut self, header: HeaderName) -> Self { + self.token_header = Some(header); + self + } + + /// Set whether to include requester pays headers + /// + /// + pub fn with_request_payer(mut self, request_payer: bool) -> Self { + self.request_payer = request_payer; + self + } + + /// Authorize `request` with an optional pre-calculated SHA256 digest by attaching + /// the relevant [AWS SigV4] headers + /// + /// # Payload Signature + /// + /// AWS SigV4 requests must contain the `x-amz-content-sha256` header, it is set as follows: + /// + /// * If not configured to sign payloads, it is set to `UNSIGNED-PAYLOAD` + /// * If a `pre_calculated_digest` is provided, it is set to the hex encoding of it + /// * If it is a streaming request, it is set to `STREAMING-AWS4-HMAC-SHA256-PAYLOAD` + /// * Otherwise it is set to the hex encoded SHA256 of the request body + /// + /// [AWS SigV4]: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html + pub fn authorize(&self, request: &mut HttpRequest, pre_calculated_digest: Option<&[u8]>) { + let url = Url::parse(&request.uri().to_string()).unwrap(); + + if let Some(ref token) = self.credential.token { + let token_val = HeaderValue::from_str(token).unwrap(); + let header = self.token_header.as_ref().unwrap_or(&TOKEN_HEADER); + request.headers_mut().insert(header, token_val); + } + + let host = &url[url::Position::BeforeHost..url::Position::AfterPort]; + let host_val = HeaderValue::from_str(host).unwrap(); + request.headers_mut().insert("host", host_val); + + let date = self.date.unwrap_or_else(Utc::now); + let date_str = date.format("%Y%m%dT%H%M%SZ").to_string(); + let date_val = HeaderValue::from_str(&date_str).unwrap(); + request.headers_mut().insert(&DATE_HEADER, date_val); + + let digest = match self.sign_payload { + false => UNSIGNED_PAYLOAD.to_string(), + true => match pre_calculated_digest { + Some(digest) => hex_encode(digest), + None => match request.body().is_empty() { + true => EMPTY_SHA256_HASH.to_string(), + false => match request.body().as_bytes() { + Some(bytes) => hex_digest(bytes), + None => STREAMING_PAYLOAD.to_string(), + }, + }, + }, + }; + + let header_digest = HeaderValue::from_str(&digest).unwrap(); + request.headers_mut().insert(&HASH_HEADER, header_digest); + + if self.request_payer { + // For DELETE, GET, HEAD, POST, and PUT requests, include x-amz-request-payer : + // requester in the header + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ObjectsinRequesterPaysBuckets.html + request + .headers_mut() + .insert(&REQUEST_PAYER_HEADER, REQUEST_PAYER_HEADER_VALUE.clone()); + } + + let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + + let scope = self.scope(date); + + let string_to_sign = self.string_to_sign( + date, + &scope, + request.method(), + &url, + &canonical_headers, + &signed_headers, + &digest, + ); + + // sign the string + let signature = self + .credential + .sign(&string_to_sign, date, self.region, self.service); + + // build the actual auth header + let authorisation = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + ALGORITHM, self.credential.key_id, 
scope, signed_headers, signature + ); + + let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); + request + .headers_mut() + .insert(&AUTHORIZATION, authorization_val); + } + + pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) { + let date = self.date.unwrap_or_else(Utc::now); + let scope = self.scope(date); + + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + url.query_pairs_mut() + .append_pair("X-Amz-Algorithm", ALGORITHM) + .append_pair( + "X-Amz-Credential", + &format!("{}/{}", self.credential.key_id, scope), + ) + .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string()) + .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string()) + .append_pair("X-Amz-SignedHeaders", "host"); + + if self.request_payer { + // For signed URLs, include x-amz-request-payer=requester in the request + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/ObjectsinRequesterPaysBuckets.html + url.query_pairs_mut() + .append_pair("x-amz-request-payer", "requester"); + } + + // For S3, you must include the X-Amz-Security-Token query parameter in the URL if + // using credentials sourced from the STS service. + if let Some(ref token) = self.credential.token { + url.query_pairs_mut() + .append_pair("X-Amz-Security-Token", token); + } + + // We don't have a payload; the user is going to send the payload directly themselves. + let digest = UNSIGNED_PAYLOAD; + + let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string(); + let mut headers = HeaderMap::new(); + let host_val = HeaderValue::from_str(host).unwrap(); + headers.insert("host", host_val); + + let (signed_headers, canonical_headers) = canonicalize_headers(&headers); + + let string_to_sign = self.string_to_sign( + date, + &scope, + &method, + url, + &canonical_headers, + &signed_headers, + digest, + ); + + let signature = self + .credential + .sign(&string_to_sign, date, self.region, self.service); + + url.query_pairs_mut() + .append_pair("X-Amz-Signature", &signature); + } + + #[allow(clippy::too_many_arguments)] + fn string_to_sign( + &self, + date: DateTime, + scope: &str, + request_method: &Method, + url: &Url, + canonical_headers: &str, + signed_headers: &str, + digest: &str, + ) -> String { + // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets + // URI-encoded once). 
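+        // e.g. a segment containing a space is already `%20` in the request URL; for
+        // non-S3 services the canonical URI encodes it again, producing `%2520`.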
+ // see https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + let canonical_uri = match self.service { + "s3" => url.path().to_string(), + _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(), + }; + + let canonical_query = canonicalize_query(url); + + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + request_method.as_str(), + canonical_uri, + canonical_query, + canonical_headers, + signed_headers, + digest + ); + + let hashed_canonical_request = hex_digest(canonical_request.as_bytes()); + + format!( + "{}\n{}\n{}\n{}", + ALGORITHM, + date.format("%Y%m%dT%H%M%SZ"), + scope, + hashed_canonical_request + ) + } + + fn scope(&self, date: DateTime) -> String { + format!( + "{}/{}/{}/aws4_request", + date.format("%Y%m%d"), + self.region, + self.service + ) + } +} + +pub(crate) trait CredentialExt { + /// Sign a request + fn with_aws_sigv4( + self, + authorizer: Option>, + payload_sha256: Option<&[u8]>, + ) -> Self; +} + +impl CredentialExt for HttpRequestBuilder { + fn with_aws_sigv4( + self, + authorizer: Option>, + payload_sha256: Option<&[u8]>, + ) -> Self { + match authorizer { + Some(authorizer) => { + let (client, request) = self.into_parts(); + let mut request = request.expect("request valid"); + authorizer.authorize(&mut request, payload_sha256); + + Self::from_parts(client, request) + } + None => self, + } + } +} + +/// Canonicalizes query parameters into the AWS canonical form +/// +/// +fn canonicalize_query(url: &Url) -> String { + use std::fmt::Write; + + let capacity = match url.query() { + Some(q) if !q.is_empty() => q.len(), + _ => return String::new(), + }; + let mut encoded = String::with_capacity(capacity + 1); + + let mut headers = url.query_pairs().collect::>(); + headers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b)); + + let mut first = true; + for (k, v) in headers { + if !first { + encoded.push('&'); + } + first = false; + let _ = write!( + encoded, + "{}={}", + utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET), + utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET) + ); + } + encoded +} + +/// Canonicalizes headers into the AWS Canonical Form. 
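+///
+/// Returns the `;`-separated list of signed header names (lowercase, sorted) together
+/// with the canonical `name:value\n` block; `authorization`, `content-length` and
+/// `user-agent` are excluded from signing.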
+/// +/// +fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) { + let mut headers = BTreeMap::<&str, Vec<&str>>::new(); + let mut value_count = 0; + let mut value_bytes = 0; + let mut key_bytes = 0; + + for (key, value) in header_map { + let key = key.as_str(); + if ["authorization", "content-length", "user-agent"].contains(&key) { + continue; + } + + let value = std::str::from_utf8(value.as_bytes()).unwrap(); + key_bytes += key.len(); + value_bytes += value.len(); + value_count += 1; + headers.entry(key).or_default().push(value); + } + + let mut signed_headers = String::with_capacity(key_bytes + headers.len()); + let mut canonical_headers = + String::with_capacity(key_bytes + value_bytes + headers.len() + value_count); + + for (header_idx, (name, values)) in headers.into_iter().enumerate() { + if header_idx != 0 { + signed_headers.push(';'); + } + + signed_headers.push_str(name); + canonical_headers.push_str(name); + canonical_headers.push(':'); + for (value_idx, value) in values.into_iter().enumerate() { + if value_idx != 0 { + canonical_headers.push(','); + } + canonical_headers.push_str(value.trim()); + } + canonical_headers.push('\n'); + } + + (signed_headers, canonical_headers) +} + +/// Credentials sourced from the instance metadata service +/// +/// +#[derive(Debug)] +pub(crate) struct InstanceCredentialProvider { + pub imdsv1_fallback: bool, + pub metadata_endpoint: String, +} + +#[async_trait] +impl TokenProvider for InstanceCredentialProvider { + type Credential = AwsCredential; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> Result>> { + instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback) + .await + .map_err(|source| crate::Error::Generic { + store: STORE, + source, + }) + } +} + +/// Credentials sourced using AssumeRoleWithWebIdentity +/// +/// +#[derive(Debug)] +pub(crate) struct WebIdentityProvider { + pub token_path: String, + pub role_arn: String, + pub session_name: String, + pub endpoint: String, +} + +#[async_trait] +impl TokenProvider for WebIdentityProvider { + type Credential = AwsCredential; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> Result>> { + web_identity( + client, + retry, + &self.token_path, + &self.role_arn, + &self.session_name, + &self.endpoint, + ) + .await + .map_err(|source| crate::Error::Generic { + store: STORE, + source, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct InstanceCredentials { + access_key_id: String, + secret_access_key: String, + token: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: InstanceCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.token), + } + } +} + +/// +async fn instance_creds( + client: &HttpClient, + retry_config: &RetryConfig, + endpoint: &str, + imdsv1_fallback: bool, +) -> Result>, StdError> { + const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials"; + const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token"; + + let token_url = format!("{endpoint}/latest/api/token"); + + let token_result = client + .request(Method::PUT, token_url) + .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL + .retryable(retry_config) + .idempotent(true) + .send() + .await; + + let token = match token_result { + Ok(t) => Some(t.into_body().text().await?), + Err(e) if imdsv1_fallback && matches!(e.status(), 
Some(StatusCode::FORBIDDEN)) => { + warn!("received 403 from metadata endpoint, falling back to IMDSv1"); + None + } + Err(e) => return Err(e.into()), + }; + + let role_url = format!("{endpoint}/{CREDENTIALS_PATH}/"); + let mut role_request = client.request(Method::GET, role_url); + + if let Some(token) = &token { + role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let role = role_request + .send_retry(retry_config) + .await? + .into_body() + .text() + .await?; + + let creds_url = format!("{endpoint}/{CREDENTIALS_PATH}/{role}"); + let mut creds_request = client.request(Method::GET, creds_url); + if let Some(token) = &token { + creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let creds: InstanceCredentials = creds_request + .send_retry(retry_config) + .await? + .into_body() + .json() + .await?; + + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Some(Instant::now() + ttl), + }) +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResponse { + assume_role_with_web_identity_result: AssumeRoleResult, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResult { + credentials: SessionCredentials, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct SessionCredentials { + session_token: String, + secret_access_key: String, + access_key_id: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: SessionCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.session_token), + } + } +} + +/// +async fn web_identity( + client: &HttpClient, + retry_config: &RetryConfig, + token_path: &str, + role_arn: &str, + session_name: &str, + endpoint: &str, +) -> Result>, StdError> { + let token = std::fs::read_to_string(token_path) + .map_err(|e| format!("Failed to read token file '{token_path}': {e}"))?; + + let bytes = client + .post(endpoint) + .query(&[ + ("Action", "AssumeRoleWithWebIdentity"), + ("DurationSeconds", "3600"), + ("RoleArn", role_arn), + ("RoleSessionName", session_name), + ("Version", "2011-06-15"), + ("WebIdentityToken", &token), + ]) + .retryable(retry_config) + .idempotent(true) + .sensitive(true) + .send() + .await? 
+ .into_body() + .bytes() + .await?; + + let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader()) + .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {e}"))?; + + let creds = resp.assume_role_with_web_identity_result.credentials; + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Some(Instant::now() + ttl), + }) +} + +/// Credentials sourced from a task IAM role +/// +/// +#[derive(Debug)] +pub(crate) struct TaskCredentialProvider { + pub url: String, + pub retry: RetryConfig, + pub client: HttpClient, + pub cache: TokenCache>, +} + +#[async_trait] +impl CredentialProvider for TaskCredentialProvider { + type Credential = AwsCredential; + + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| task_credential(&self.client, &self.retry, &self.url)) + .await + .map_err(|source| crate::Error::Generic { + store: STORE, + source, + }) + } +} + +/// +async fn task_credential( + client: &HttpClient, + retry: &RetryConfig, + url: &str, +) -> Result>, StdError> { + let creds: InstanceCredentials = client + .get(url) + .send_retry(retry) + .await? + .into_body() + .json() + .await?; + + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Some(Instant::now() + ttl), + }) +} + +/// A session provider as used by S3 Express One Zone +/// +/// +#[derive(Debug)] +pub(crate) struct SessionProvider { + pub endpoint: String, + pub region: String, + pub credentials: AwsCredentialProvider, +} + +#[async_trait] +impl TokenProvider for SessionProvider { + type Credential = AwsCredential; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> Result>> { + let creds = self.credentials.get_credential().await?; + let authorizer = AwsAuthorizer::new(&creds, "s3", &self.region); + + let bytes = client + .get(format!("{}?session", self.endpoint)) + .with_aws_sigv4(Some(authorizer), None) + .send_retry(retry) + .await + .map_err(|source| Error::CreateSessionRequest { source })? 
+ .into_body() + .bytes() + .await + .map_err(|source| Error::CreateSessionResponse { source })?; + + let resp: CreateSessionOutput = quick_xml::de::from_reader(bytes.reader()) + .map_err(|source| Error::CreateSessionOutput { source })?; + + let creds = resp.credentials; + Ok(TemporaryToken { + token: Arc::new(creds.into()), + // Credentials last 5 minutes - https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateSession.html + expiry: Some(Instant::now() + Duration::from_secs(5 * 60)), + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct CreateSessionOutput { + credentials: SessionCredentials, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::mock_server::MockServer; + use crate::client::HttpClient; + use http::Response; + use reqwest::{Client, Method}; + use std::env; + + // Test generated using https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html + #[test] + fn test_sign_with_signed_payload() { + let client = HttpClient::new(Client::new()); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .into_parts() + .1 + .unwrap(); + + let signer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "ec2", + region: "us-east-1", + sign_payload: true, + token_header: None, + request_payer: false, + }; + + signer.authorize(&mut request, None); + assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4") + } + + #[test] + fn test_sign_with_signed_payload_request_payer() { + let client = HttpClient::new(Client::new()); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .into_parts() + .1 + .unwrap(); + + let signer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "ec2", + region: "us-east-1", + sign_payload: true, + token_header: None, + request_payer: true, + }; + + signer.authorize(&mut request, None); + assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-request-payer, Signature=7030625a9e9b57ed2a40e63d749f4a4b7714b6e15004cab026152f870dd8565d") + } + + 
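+    #[test]
+    fn test_canonicalize_headers() {
+        // A minimal illustrative check of `canonicalize_headers` (header values here are
+        // arbitrary): repeated values are comma-joined, names are sorted, and unsigned
+        // headers such as `user-agent` are omitted from both outputs.
+        let mut headers = HeaderMap::new();
+        headers.insert("host", HeaderValue::from_static("localhost:9000"));
+        headers.insert("x-amz-date", HeaderValue::from_static("20220809T130525Z"));
+        headers.append("x-amz-meta-tag", HeaderValue::from_static("a"));
+        headers.append("x-amz-meta-tag", HeaderValue::from_static("b"));
+        headers.insert("user-agent", HeaderValue::from_static("object_store"));
+
+        let (signed, canonical) = canonicalize_headers(&headers);
+        assert_eq!(signed, "host;x-amz-date;x-amz-meta-tag");
+        assert_eq!(
+            canonical,
+            "host:localhost:9000\nx-amz-date:20220809T130525Z\nx-amz-meta-tag:a,b\n"
+        );
+    }
+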
#[test] + fn test_sign_with_unsigned_payload() { + let client = HttpClient::new(Client::new()); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .into_parts() + .1 + .unwrap(); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "ec2", + region: "us-east-1", + token_header: None, + sign_payload: false, + request_payer: false, + }; + + authorizer.authorize(&mut request, None); + assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699"); + } + + #[test] + fn signed_get_url() { + // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z") + .unwrap() + .with_timezone(&Utc); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + token_header: None, + sign_payload: false, + request_payer: false, + }; + + let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400)); + + assert_eq!( + url, + Url::parse( + "https://examplebucket.s3.amazonaws.com/test.txt?\ + X-Amz-Algorithm=AWS4-HMAC-SHA256&\ + X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\ + X-Amz-Date=20130524T000000Z&\ + X-Amz-Expires=86400&\ + X-Amz-SignedHeaders=host&\ + X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404" + ) + .unwrap() + ); + } + + #[test] + fn signed_get_url_request_payer() { + // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z") + .unwrap() + .with_timezone(&Utc); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + token_header: None, + sign_payload: false, + request_payer: true, + }; + + let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400)); + + assert_eq!( + url, + Url::parse( + "https://examplebucket.s3.amazonaws.com/test.txt?\ + X-Amz-Algorithm=AWS4-HMAC-SHA256&\ + X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\ + X-Amz-Date=20130524T000000Z&\ + X-Amz-Expires=86400&\ + X-Amz-SignedHeaders=host&\ + 
x-amz-request-payer=requester&\ + X-Amz-Signature=9ad7c781cc30121f199b47d35ed3528473e4375b63c5d91cd87c927803e4e00a" + ) + .unwrap() + ); + } + + #[test] + fn test_sign_port() { + let client = HttpClient::new(Client::new()); + + let credential = AwsCredential { + key_id: "H20ABqCkLZID4rLe".to_string(), + secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "http://localhost:9000/tsm-schemas") + .query(&[ + ("delimiter", "/"), + ("encoding-type", "url"), + ("list-type", "2"), + ("prefix", ""), + ]) + .into_parts() + .1 + .unwrap(); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + token_header: None, + sign_payload: true, + request_payer: false, + }; + + authorizer.authorize(&mut request, None); + assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d") + } + + #[tokio::test] + async fn test_instance_metadata() { + if env::var("TEST_INTEGRATION").is_err() { + eprintln!("skipping AWS integration test"); + return; + } + + // For example https://github.com/aws/amazon-ec2-metadata-mock + let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap(); + let client = HttpClient::new(Client::new()); + let retry_config = RetryConfig::default(); + + // Verify only allows IMDSv2 + let (client, req) = client + .request(Method::GET, format!("{endpoint}/latest/meta-data/ami-id")) + .into_parts(); + + let resp = client.execute(req.unwrap()).await.unwrap(); + + assert_eq!( + resp.status(), + StatusCode::UNAUTHORIZED, + "Ensure metadata endpoint is set to only allow IMDSv2" + ); + + let creds = instance_creds(&client, &retry_config, &endpoint, false) + .await + .unwrap(); + + let id = &creds.token.key_id; + let secret = &creds.token.secret_key; + let token = creds.token.token.as_ref().unwrap(); + + assert!(!id.is_empty()); + assert!(!secret.is_empty()); + assert!(!token.is_empty()) + } + + #[tokio::test] + async fn test_mock() { + let server = MockServer::new().await; + + const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token"; + + let secret_access_key = "SECRET"; + let access_key_id = "KEYID"; + let token = "TOKEN"; + + let endpoint = server.url(); + let client = HttpClient::new(Client::new()); + let retry_config = RetryConfig::default(); + + // Test IMDSv2 + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::new("cupcakes".to_string()) + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + Response::new("myrole".to_string()) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + 
Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string()) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(String::new()) + .unwrap() + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + Response::new("myrole".to_string()) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string()) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback disabled + server.push( + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(String::new()) + .unwrap(), + ); + + // Should fail + instance_creds(&client, &retry_config, endpoint, false) + .await + .unwrap_err(); + } +} diff --git a/src/aws/dynamo.rs b/src/aws/dynamo.rs new file mode 100644 index 0000000..73380aa --- /dev/null +++ b/src/aws/dynamo.rs @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A DynamoDB based lock system + +use std::borrow::Cow; +use std::collections::HashMap; +use std::future::Future; +use std::time::{Duration, Instant}; + +use chrono::Utc; +use http::{Method, StatusCode}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::aws::client::S3Client; +use crate::aws::credential::CredentialExt; +use crate::aws::{AwsAuthorizer, AwsCredential}; +use crate::client::get::GetClientExt; +use crate::client::retry::RetryExt; +use crate::client::retry::{RequestError, RetryError}; +use crate::path::Path; +use crate::{Error, GetOptions, Result}; + +/// The exception returned by DynamoDB on conflict +const CONFLICT: &str = "ConditionalCheckFailedException"; + +const STORE: &str = "DynamoDB"; + +/// A DynamoDB-based commit protocol, used to provide conditional write support for S3 +/// +/// ## Limitations +/// +/// Only conditional operations, e.g. `copy_if_not_exists` will be synchronized, and can +/// therefore race with non-conditional operations, e.g. `put`, `copy`, `delete`, or +/// conditional operations performed by writers not configured to synchronize with DynamoDB. +/// +/// Workloads making use of this mechanism **must** ensure: +/// +/// * Conditional and non-conditional operations are not performed on the same paths +/// * Conditional operations are only performed via similarly configured clients +/// +/// Additionally as the locking mechanism relies on timeouts to detect stale locks, +/// performance will be poor for systems that frequently delete and then create +/// objects at the same path, instead being optimised for systems that primarily create +/// files with paths never used before, or perform conditional updates to existing files +/// +/// ## Commit Protocol +/// +/// The DynamoDB schema is as follows: +/// +/// * A string partition key named `"path"` +/// * A string sort key named `"etag"` +/// * A numeric [TTL] attribute named `"ttl"` +/// * A numeric attribute named `"generation"` +/// * A numeric attribute named `"timeout"` +/// +/// An appropriate DynamoDB table can be created with the CLI as follows: +/// +/// ```bash +/// $ aws dynamodb create-table --table-name --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S +/// $ aws dynamodb update-time-to-live --table-name --time-to-live-specification Enabled=true,AttributeName=ttl +/// ``` +/// +/// To perform a conditional operation on an object with a given `path` and `etag` (`*` if creating), +/// the commit protocol is as follows: +/// +/// 1. Perform HEAD request on `path` and error on precondition mismatch +/// 2. Create record in DynamoDB with given `path` and `etag` with the configured timeout +/// 1. On Success: Perform operation with the configured timeout +/// 2. On Conflict: +/// 1. Periodically re-perform HEAD request on `path` and error on precondition mismatch +/// 2. If `timeout * max_skew_rate` passed, replace the record incrementing the `"generation"` +/// 1. On Success: GOTO 2.1 +/// 2. On Conflict: GOTO 2.2 +/// +/// Provided no writer modifies an object with a given `path` and `etag` without first adding a +/// corresponding record to DynamoDB, we are guaranteed that only one writer will ever commit. +/// +/// This is inspired by the [DynamoDB Lock Client] but simplified for the more limited +/// requirements of synchronizing object storage. 
The major changes are: +/// +/// * Uses a monotonic generation count instead of a UUID rvn, as this is: +/// * Cheaper to generate, serialize and compare +/// * Cannot collide +/// * More human readable / interpretable +/// * Relies on [TTL] to eventually clean up old locks +/// +/// It also draws inspiration from the DeltaLake [S3 Multi-Cluster] commit protocol, but +/// generalised to not make assumptions about the workload and not rely on first writing +/// to a temporary path. +/// +/// [TTL]: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html +/// [DynamoDB Lock Client]: https://aws.amazon.com/blogs/database/building-distributed-locks-with-the-dynamodb-lock-client/ +/// [S3 Multi-Cluster]: https://docs.google.com/document/d/1Gs4ZsTH19lMxth4BSdwlWjUNR-XhKHicDvBjd2RqNd8/edit#heading=h.mjjuxw9mcz9h +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct DynamoCommit { + table_name: String, + /// The number of milliseconds a lease is valid for + timeout: u64, + /// The maximum clock skew rate tolerated by the system + max_clock_skew_rate: u32, + /// The length of time a record will be retained in DynamoDB before being cleaned up + /// + /// This is purely an optimisation to avoid indefinite growth of the DynamoDB table + /// and does not impact how long clients may wait to acquire a lock + ttl: Duration, + /// The backoff duration before retesting a condition + test_interval: Duration, +} + +impl DynamoCommit { + /// Create a new [`DynamoCommit`] with a given table name + pub fn new(table_name: String) -> Self { + Self { + table_name, + timeout: 20_000, + max_clock_skew_rate: 3, + ttl: Duration::from_secs(60 * 60), + test_interval: Duration::from_millis(100), + } + } + + /// Overrides the lock timeout. + /// + /// A longer lock timeout reduces the probability of spurious commit failures and multi-writer + /// races, but will increase the time that writers must wait to reclaim a lock lost. The + /// default value of 20 seconds should be appropriate for must use-cases. + pub fn with_timeout(mut self, millis: u64) -> Self { + self.timeout = millis; + self + } + + /// The maximum clock skew rate tolerated by the system. + /// + /// An environment in which the clock on the fastest node ticks twice as fast as the slowest + /// node, would have a clock skew rate of 2. The default value of 3 should be appropriate + /// for most environments. + pub fn with_max_clock_skew_rate(mut self, rate: u32) -> Self { + self.max_clock_skew_rate = rate; + self + } + + /// The length of time a record should be retained in DynamoDB before being cleaned up + /// + /// This should be significantly larger than the configured lock timeout, with the default + /// value of 1 hour appropriate for most use-cases. + pub fn with_ttl(mut self, ttl: Duration) -> Self { + self.ttl = ttl; + self + } + + /// Parse [`DynamoCommit`] from a string + pub(crate) fn from_str(value: &str) -> Option { + Some(match value.split_once(':') { + Some((table_name, timeout)) => { + Self::new(table_name.trim().to_string()).with_timeout(timeout.parse().ok()?) + } + None => Self::new(value.trim().to_string()), + }) + } + + /// Returns the name of the DynamoDB table. 
+ pub(crate) fn table_name(&self) -> &str { + &self.table_name + } + + pub(crate) async fn copy_if_not_exists( + &self, + client: &S3Client, + from: &Path, + to: &Path, + ) -> Result<()> { + self.conditional_op(client, to, None, || async { + client.copy_request(from, to).send().await?; + Ok(()) + }) + .await + } + + #[allow(clippy::future_not_send)] // Generics confound this lint + pub(crate) async fn conditional_op( + &self, + client: &S3Client, + to: &Path, + etag: Option<&str>, + op: F, + ) -> Result + where + F: FnOnce() -> Fut, + Fut: Future>, + { + check_precondition(client, to, etag).await?; + + let mut previous_lease = None; + + loop { + let existing = previous_lease.as_ref(); + match self.try_lock(client, to.as_ref(), etag, existing).await? { + TryLockResult::Ok(lease) => { + let expiry = lease.acquire + lease.timeout; + return match tokio::time::timeout_at(expiry.into(), op()).await { + Ok(Ok(v)) => Ok(v), + Ok(Err(e)) => Err(e), + Err(_) => Err(Error::Generic { + store: "DynamoDB", + source: format!( + "Failed to perform conditional operation in {} milliseconds", + self.timeout + ) + .into(), + }), + }; + } + TryLockResult::Conflict(conflict) => { + let mut interval = tokio::time::interval(self.test_interval); + let expiry = conflict.timeout * self.max_clock_skew_rate; + loop { + interval.tick().await; + check_precondition(client, to, etag).await?; + if conflict.acquire.elapsed() > expiry { + previous_lease = Some(conflict); + break; + } + } + } + } + } + } + + /// Attempt to acquire a lock, reclaiming an existing lease if provided + async fn try_lock( + &self, + s3: &S3Client, + path: &str, + etag: Option<&str>, + existing: Option<&Lease>, + ) -> Result { + let attributes; + let (next_gen, condition_expression, expression_attribute_values) = match existing { + None => (0_u64, "attribute_not_exists(#pk)", Map(&[])), + Some(existing) => { + attributes = [(":g", AttributeValue::Number(existing.generation))]; + ( + existing.generation.checked_add(1).unwrap(), + "attribute_exists(#pk) AND generation = :g", + Map(attributes.as_slice()), + ) + } + }; + + let ttl = (Utc::now() + self.ttl).timestamp(); + let items = [ + ("path", AttributeValue::from(path)), + ("etag", AttributeValue::from(etag.unwrap_or("*"))), + ("generation", AttributeValue::Number(next_gen)), + ("timeout", AttributeValue::Number(self.timeout)), + ("ttl", AttributeValue::Number(ttl as _)), + ]; + let names = [("#pk", "path")]; + + let req = PutItem { + table_name: &self.table_name, + condition_expression, + expression_attribute_values, + expression_attribute_names: Map(&names), + item: Map(&items), + return_values: None, + return_values_on_condition_check_failure: Some(ReturnValues::AllOld), + }; + + let credential = s3.config.get_credential().await?; + + let acquire = Instant::now(); + match self + .request(s3, credential.as_deref(), "DynamoDB_20120810.PutItem", req) + .await + { + Ok(_) => Ok(TryLockResult::Ok(Lease { + acquire, + generation: next_gen, + timeout: Duration::from_millis(self.timeout), + })), + Err(e) => match parse_error_response(&e) { + Some(e) if e.error.ends_with(CONFLICT) => match extract_lease(&e.item) { + Some(lease) => Ok(TryLockResult::Conflict(lease)), + None => Err(Error::Generic { + store: STORE, + source: "Failed to extract lease from conflict ReturnValuesOnConditionCheckFailure response".into() + }), + }, + _ => Err(Error::Generic { + store: STORE, + source: Box::new(e), + }), + }, + } + } + + async fn request( + &self, + s3: &S3Client, + cred: Option<&AwsCredential>, + target: &str, + 
req: R, + ) -> Result { + let region = &s3.config.region; + let authorizer = cred.map(|x| AwsAuthorizer::new(x, "dynamodb", region)); + + let builder = match &s3.config.endpoint { + Some(e) => s3.client.request(Method::POST, e), + None => { + let url = format!("https://dynamodb.{region}.amazonaws.com"); + s3.client.request(Method::POST, url) + } + }; + + // TODO: Timeout + builder + .json(&req) + .header("X-Amz-Target", target) + .with_aws_sigv4(authorizer, None) + .send_retry(&s3.config.retry_config) + .await + } +} + +#[derive(Debug)] +enum TryLockResult { + /// Successfully acquired a lease + Ok(Lease), + /// An existing lease was found + Conflict(Lease), +} + +/// Validates that `path` has the given `etag` or doesn't exist if `None` +async fn check_precondition(client: &S3Client, path: &Path, etag: Option<&str>) -> Result<()> { + let options = GetOptions { + head: true, + ..Default::default() + }; + + match etag { + Some(expected) => match client.get_opts(path, options).await { + Ok(r) => match r.meta.e_tag { + Some(actual) if expected == actual => Ok(()), + actual => Err(Error::Precondition { + path: path.to_string(), + source: format!("{} does not match {expected}", actual.unwrap_or_default()) + .into(), + }), + }, + Err(Error::NotFound { .. }) => Err(Error::Precondition { + path: path.to_string(), + source: format!("Object at location {path} not found").into(), + }), + Err(e) => Err(e), + }, + None => match client.get_opts(path, options).await { + Ok(_) => Err(Error::AlreadyExists { + path: path.to_string(), + source: "Already Exists".to_string().into(), + }), + Err(Error::NotFound { .. }) => Ok(()), + Err(e) => Err(e), + }, + } +} + +/// Parses the error response if any +fn parse_error_response(e: &RetryError) -> Option> { + match e.inner() { + RequestError::Status { + status: StatusCode::BAD_REQUEST, + body: Some(b), + } => serde_json::from_str(b).ok(), + _ => None, + } +} + +/// Extracts a lease from `item`, returning `None` on error +fn extract_lease(item: &HashMap<&str, AttributeValue<'_>>) -> Option { + let generation = match item.get("generation") { + Some(AttributeValue::Number(generation)) => generation, + _ => return None, + }; + + let timeout = match item.get("timeout") { + Some(AttributeValue::Number(timeout)) => *timeout, + _ => return None, + }; + + Some(Lease { + acquire: Instant::now(), + generation: *generation, + timeout: Duration::from_millis(timeout), + }) +} + +/// A lock lease +#[derive(Debug, Clone)] +struct Lease { + acquire: Instant, + generation: u64, + timeout: Duration, +} + +/// A DynamoDB [PutItem] payload +/// +/// [PutItem]: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_PutItem.html +#[derive(Serialize)] +#[serde(rename_all = "PascalCase")] +struct PutItem<'a> { + /// The table name + table_name: &'a str, + + /// A condition that must be satisfied in order for a conditional PutItem operation to succeed. + condition_expression: &'a str, + + /// One or more substitution tokens for attribute names in an expression + expression_attribute_names: Map<'a, &'a str, &'a str>, + + /// One or more values that can be substituted in an expression + expression_attribute_values: Map<'a, &'a str, AttributeValue<'a>>, + + /// A map of attribute name/value pairs, one for each attribute + item: Map<'a, &'a str, AttributeValue<'a>>, + + /// Use ReturnValues if you want to get the item attributes as they appeared + /// before they were updated with the PutItem request. 
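+    //
+    // For illustration (table and path names are hypothetical), acquiring a lock on a
+    // new path serializes this struct roughly as:
+    //
+    //   {"TableName": "my-table",
+    //    "ConditionExpression": "attribute_not_exists(#pk)",
+    //    "ExpressionAttributeNames": {"#pk": "path"},
+    //    "Item": {"path": {"S": "data/file"}, "etag": {"S": "*"},
+    //             "generation": {"N": "0"}, "timeout": {"N": "20000"},
+    //             "ttl": {"N": "1700000000"}},
+    //    "ReturnValuesOnConditionCheckFailure": "ALL_OLD"}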
+ #[serde(skip_serializing_if = "Option::is_none")] + return_values: Option, + + /// An optional parameter that returns the item attributes for a PutItem operation + /// that failed a condition check. + #[serde(skip_serializing_if = "Option::is_none")] + return_values_on_condition_check_failure: Option, +} + +#[derive(Deserialize)] +struct ErrorResponse<'a> { + #[serde(rename = "__type")] + error: &'a str, + + #[serde(borrow, default, rename = "Item")] + item: HashMap<&'a str, AttributeValue<'a>>, +} + +#[derive(Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +enum ReturnValues { + AllOld, +} + +/// A collection of key value pairs +/// +/// This provides cheap, ordered serialization of maps +struct Map<'a, K, V>(&'a [(K, V)]); + +impl Serialize for Map<'_, K, V> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.0.is_empty() { + return serializer.serialize_none(); + } + let mut map = serializer.serialize_map(Some(self.0.len()))?; + for (k, v) in self.0 { + map.serialize_entry(k, v)? + } + map.end() + } +} + +/// A DynamoDB [AttributeValue] +/// +/// [AttributeValue]: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html +#[derive(Debug, Serialize, Deserialize)] +enum AttributeValue<'a> { + #[serde(rename = "S")] + String(Cow<'a, str>), + #[serde(rename = "N", with = "number")] + Number(u64), +} + +impl<'a> From<&'a str> for AttributeValue<'a> { + fn from(value: &'a str) -> Self { + Self::String(Cow::Borrowed(value)) + } +} + +/// Numbers are serialized as strings +mod number { + use serde::{Deserialize, Deserializer, Serializer}; + + pub(crate) fn serialize(v: &u64, s: S) -> Result { + s.serialize_str(&v.to_string()) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result { + let v: &str = Deserialize::deserialize(d)?; + v.parse().map_err(serde::de::Error::custom) + } +} + +use crate::client::HttpResponse; +/// Re-export integration_test to be called by s3_test +#[cfg(test)] +pub(crate) use tests::integration_test; + +#[cfg(test)] +mod tests { + use super::*; + use crate::aws::AmazonS3; + use crate::ObjectStore; + use rand::distributions::Alphanumeric; + use rand::{thread_rng, Rng}; + + #[test] + fn test_attribute_serde() { + let serde = serde_json::to_string(&AttributeValue::Number(23)).unwrap(); + assert_eq!(serde, "{\"N\":\"23\"}"); + let back: AttributeValue<'_> = serde_json::from_str(&serde).unwrap(); + assert!(matches!(back, AttributeValue::Number(23))); + } + + /// An integration test for DynamoDB + /// + /// This is a function called by s3_test to avoid test concurrency issues + pub(crate) async fn integration_test(integration: &AmazonS3, d: &DynamoCommit) { + let client = integration.client.as_ref(); + + let src = Path::from("dynamo_path_src"); + integration.put(&src, "asd".into()).await.unwrap(); + + let dst = Path::from("dynamo_path"); + let _ = integration.delete(&dst).await; // Delete if present + + // Create a lock if not already exists + let existing = match d.try_lock(client, dst.as_ref(), None, None).await.unwrap() { + TryLockResult::Conflict(l) => l, + TryLockResult::Ok(l) => l, + }; + + // Should not be able to acquire a lock again + let r = d.try_lock(client, dst.as_ref(), None, None).await; + assert!(matches!(r, Ok(TryLockResult::Conflict(_)))); + + // But should still be able to reclaim lock and perform copy + d.copy_if_not_exists(client, &src, &dst).await.unwrap(); + + match d.try_lock(client, dst.as_ref(), None, None).await.unwrap() { + TryLockResult::Conflict(new) 
=> { + // Should have incremented generation to do so + assert_eq!(new.generation, existing.generation + 1); + } + _ => panic!("Should conflict"), + } + + let rng = thread_rng(); + let etag = String::from_utf8(rng.sample_iter(Alphanumeric).take(32).collect()).unwrap(); + let t = Some(etag.as_str()); + + let l = match d.try_lock(client, dst.as_ref(), t, None).await.unwrap() { + TryLockResult::Ok(l) => l, + _ => panic!("should not conflict"), + }; + + match d.try_lock(client, dst.as_ref(), t, None).await.unwrap() { + TryLockResult::Conflict(c) => assert_eq!(l.generation, c.generation), + _ => panic!("should conflict"), + } + + match d.try_lock(client, dst.as_ref(), t, Some(&l)).await.unwrap() { + TryLockResult::Ok(new) => assert_eq!(new.generation, l.generation + 1), + _ => panic!("should not conflict"), + } + } +} diff --git a/src/aws/mod.rs b/src/aws/mod.rs new file mode 100644 index 0000000..b8175bd --- /dev/null +++ b/src/aws/mod.rs @@ -0,0 +1,823 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for S3 +//! +//! ## Multipart uploads +//! +//! Multipart uploads can be initiated with the [ObjectStore::put_multipart] method. +//! +//! If the writer fails for any reason, you may have parts uploaded to AWS but not +//! used that you will be charged for. [`MultipartUpload::abort`] may be invoked to drop +//! these unneeded parts, however, it is recommended that you consider implementing +//! [automatic cleanup] of unused parts that are older than some threshold. +//! +//! 
[automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ + +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::{StreamExt, TryStreamExt}; +use reqwest::header::{HeaderName, IF_MATCH, IF_NONE_MATCH}; +use reqwest::{Method, StatusCode}; +use std::{sync::Arc, time::Duration}; +use url::Url; + +use crate::aws::client::{CompleteMultipartMode, PutPartPayload, RequestError, S3Client}; +use crate::client::get::GetClientExt; +use crate::client::list::ListClientExt; +use crate::client::CredentialProvider; +use crate::multipart::{MultipartStore, PartId}; +use crate::signer::Signer; +use crate::util::STRICT_ENCODE_SET; +use crate::{ + Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, + ObjectStore, Path, PutMode, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, + UploadPart, +}; + +static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging"); +static COPY_SOURCE_HEADER: HeaderName = HeaderName::from_static("x-amz-copy-source"); + +mod builder; +mod checksum; +mod client; +mod credential; +mod dynamo; +mod precondition; + +#[cfg(not(target_arch = "wasm32"))] +mod resolve; + +pub use builder::{AmazonS3Builder, AmazonS3ConfigKey}; +pub use checksum::Checksum; +pub use dynamo::DynamoCommit; +pub use precondition::{S3ConditionalPut, S3CopyIfNotExists}; + +#[cfg(not(target_arch = "wasm32"))] +pub use resolve::resolve_bucket_region; + +/// This struct is used to maintain the URI path encoding +const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.remove(b'/'); + +const STORE: &str = "S3"; + +/// [`CredentialProvider`] for [`AmazonS3`] +pub type AwsCredentialProvider = Arc>; +use crate::client::parts::Parts; +pub use credential::{AwsAuthorizer, AwsCredential}; + +/// Interface for [Amazon S3](https://aws.amazon.com/s3/). +#[derive(Debug, Clone)] +pub struct AmazonS3 { + client: Arc, +} + +impl std::fmt::Display for AmazonS3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "AmazonS3({})", self.client.config.bucket) + } +} + +impl AmazonS3 { + /// Returns the [`AwsCredentialProvider`] used by [`AmazonS3`] + pub fn credentials(&self) -> &AwsCredentialProvider { + &self.client.config.credentials + } + + /// Create a full URL to the resource specified by `path` with this instance's configuration. + fn path_url(&self, path: &Path) -> String { + self.client.config.path_url(path) + } +} + +#[async_trait] +impl Signer for AmazonS3 { + /// Create a URL containing the relevant [AWS SigV4] query parameters that authorize a request + /// via `method` to the resource at `path` valid for the duration specified in `expires_in`. + /// + /// [AWS SigV4]: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html + /// + /// # Example + /// + /// This example returns a URL that will enable a user to upload a file to + /// "some-folder/some-file.txt" in the next hour. 
+ /// + /// ``` + /// # async fn example() -> Result<(), Box> { + /// # use object_store::{aws::AmazonS3Builder, path::Path, signer::Signer}; + /// # use reqwest::Method; + /// # use std::time::Duration; + /// # + /// let region = "us-east-1"; + /// let s3 = AmazonS3Builder::new() + /// .with_region(region) + /// .with_bucket_name("my-bucket") + /// .with_access_key_id("my-access-key-id") + /// .with_secret_access_key("my-secret-access-key") + /// .build()?; + /// + /// let url = s3.signed_url( + /// Method::PUT, + /// &Path::from("some-folder/some-file.txt"), + /// Duration::from_secs(60 * 60) + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + async fn signed_url(&self, method: Method, path: &Path, expires_in: Duration) -> Result { + let credential = self.credentials().get_credential().await?; + let authorizer = AwsAuthorizer::new(&credential, "s3", &self.client.config.region) + .with_request_payer(self.client.config.request_payer); + + let path_url = self.path_url(path); + let mut url = path_url.parse().map_err(|e| Error::Generic { + store: STORE, + source: format!("Unable to parse url {path_url}: {e}").into(), + })?; + + authorizer.sign(method, &mut url, expires_in); + + Ok(url) + } +} + +#[async_trait] +impl ObjectStore for AmazonS3 { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let PutOptions { + mode, + tags, + attributes, + extensions, + } = opts; + + let request = self + .client + .request(Method::PUT, location) + .with_payload(payload) + .with_attributes(attributes) + .with_tags(tags) + .with_extensions(extensions) + .with_encryption_headers(); + + match (mode, &self.client.config.conditional_put) { + (PutMode::Overwrite, _) => request.idempotent(true).do_put().await, + (PutMode::Create, S3ConditionalPut::Disabled) => Err(Error::NotImplemented), + (PutMode::Create, S3ConditionalPut::ETagMatch) => { + match request.header(&IF_NONE_MATCH, "*").do_put().await { + // Technically If-None-Match should return NotModified but some stores, + // such as R2, instead return PreconditionFailed + // https://developers.cloudflare.com/r2/api/s3/extensions/#conditional-operations-in-putobject + Err(e @ Error::NotModified { .. } | e @ Error::Precondition { .. }) => { + Err(Error::AlreadyExists { + path: location.to_string(), + source: Box::new(e), + }) + } + r => r, + } + } + (PutMode::Create, S3ConditionalPut::Dynamo(d)) => { + d.conditional_op(&self.client, location, None, move || request.do_put()) + .await + } + (PutMode::Update(v), put) => { + let etag = v.e_tag.ok_or_else(|| Error::Generic { + store: STORE, + source: "ETag required for conditional put".to_string().into(), + })?; + match put { + S3ConditionalPut::ETagMatch => { + match request + .header(&IF_MATCH, etag.as_str()) + // Real S3 will occasionally report 409 Conflict + // if there are concurrent `If-Match` requests + // in flight, so we need to be prepared to retry + // 409 responses. + .retry_on_conflict(true) + .do_put() + .await + { + // Real S3 reports NotFound rather than PreconditionFailed when the + // object doesn't exist. Convert to PreconditionFailed for + // consistency with R2. This also matches what the HTTP spec + // says the behavior should be. 
+ Err(Error::NotFound { path, source }) => { + Err(Error::Precondition { path, source }) + } + r => r, + } + } + S3ConditionalPut::Dynamo(d) => { + d.conditional_op(&self.client, location, Some(&etag), move || { + request.do_put() + }) + .await + } + S3ConditionalPut::Disabled => Err(Error::NotImplemented), + } + } + } + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + let upload_id = self.client.create_multipart(location, opts).await?; + + Ok(Box::new(S3MultiPartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + upload_id: upload_id.clone(), + parts: Default::default(), + }), + })) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.client.get_opts(location, options).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.request(Method::DELETE, location).send().await?; + Ok(()) + } + + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + locations + .try_chunks(1_000) + .map(move |locations| async { + // Early return the error. We ignore the paths that have already been + // collected into the chunk. + let locations = locations.map_err(|e| e.1)?; + self.client + .bulk_delete_request(locations) + .await + .map(futures::stream::iter) + }) + .buffered(20) + .try_flatten() + .boxed() + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.client.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + if self.client.config.is_s3_express() { + let offset = offset.clone(); + // S3 Express does not support start-after + return self + .client + .list(prefix) + .try_filter(move |f| futures::future::ready(f.location > offset)) + .boxed(); + } + + self.client.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.client.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client + .copy_request(from, to) + .idempotent(true) + .send() + .await?; + Ok(()) + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let (k, v, status) = match &self.client.config.copy_if_not_exists { + Some(S3CopyIfNotExists::Header(k, v)) => (k, v, StatusCode::PRECONDITION_FAILED), + Some(S3CopyIfNotExists::HeaderWithStatus(k, v, status)) => (k, v, *status), + Some(S3CopyIfNotExists::Multipart) => { + let upload_id = self + .client + .create_multipart(to, PutMultipartOpts::default()) + .await?; + + let res = async { + let part_id = self + .client + .put_part(to, &upload_id, 0, PutPartPayload::Copy(from)) + .await?; + match self + .client + .complete_multipart( + to, + &upload_id, + vec![part_id], + CompleteMultipartMode::Create, + ) + .await + { + Err(e @ Error::Precondition { .. }) => Err(Error::AlreadyExists { + path: to.to_string(), + source: Box::new(e), + }), + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + .await; + + // If the multipart upload failed, make a best effort attempt to + // clean it up. It's the caller's responsibility to add a + // lifecycle rule if guaranteed cleanup is required, as we + // cannot protect against an ill-timed process crash. 
+ if res.is_err() { + let _ = self.client.abort_multipart(to, &upload_id).await; + } + + return res; + } + Some(S3CopyIfNotExists::Dynamo(lock)) => { + return lock.copy_if_not_exists(&self.client, from, to).await + } + None => { + return Err(Error::NotSupported { + source: "S3 does not support copy-if-not-exists".to_string().into(), + }) + } + }; + + let req = self.client.copy_request(from, to); + match req.header(k, v).send().await { + Err(RequestError::Retry { source, path }) if source.status() == Some(status) => { + Err(Error::AlreadyExists { + source: Box::new(source), + path, + }) + } + Err(e) => Err(e.into()), + Ok(_) => Ok(()), + } + } +} + +#[derive(Debug)] +struct S3MultiPartUpload { + part_idx: usize, + state: Arc, +} + +#[derive(Debug)] +struct UploadState { + parts: Parts, + location: Path, + upload_id: String, + client: Arc, +} + +#[async_trait] +impl MultipartUpload for S3MultiPartUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state + .client + .put_part( + &state.location, + &state.upload_id, + idx, + PutPartPayload::Part(data), + ) + .await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .complete_multipart( + &self.state.location, + &self.state.upload_id, + parts, + CompleteMultipartMode::Overwrite, + ) + .await + } + + async fn abort(&mut self) -> Result<()> { + self.state + .client + .request(Method::DELETE, &self.state.location) + .query(&[("uploadId", &self.state.upload_id)]) + .idempotent(true) + .send() + .await?; + + Ok(()) + } +} + +#[async_trait] +impl MultipartStore for AmazonS3 { + async fn create_multipart(&self, path: &Path) -> Result { + self.client + .create_multipart(path, PutMultipartOpts::default()) + .await + } + + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + data: PutPayload, + ) -> Result { + self.client + .put_part(path, id, part_idx, PutPartPayload::Part(data)) + .await + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result { + self.client + .complete_multipart(path, id, parts, CompleteMultipartMode::Overwrite) + .await + } + + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> { + self.client + .request(Method::DELETE, path) + .query(&[("uploadId", id)]) + .send() + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::get::GetClient; + use crate::integration::*; + use crate::tests::*; + use crate::ClientOptions; + use base64::prelude::BASE64_STANDARD; + use base64::Engine; + use http::HeaderMap; + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn write_multipart_file_with_signature() { + maybe_skip_integration!(); + + let store = AmazonS3Builder::from_env() + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + 
store.delete(&path).await.unwrap(); + } + + #[tokio::test] + async fn write_multipart_file_with_signature_object_lock() { + maybe_skip_integration!(); + + let bucket = "test-object-lock"; + let store = AmazonS3Builder::from_env() + .with_bucket_name(bucket) + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + store.delete(&path).await.unwrap(); + } + + #[tokio::test] + async fn s3_test() { + maybe_skip_integration!(); + let config = AmazonS3Builder::from_env(); + + let integration = config.build().unwrap(); + let config = &integration.client.config; + let test_not_exists = config.copy_if_not_exists.is_some(); + let test_conditional_put = config.conditional_put != S3ConditionalPut::Disabled; + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + multipart(&integration, &integration).await; + multipart_race_condition(&integration, true).await; + multipart_out_of_order(&integration).await; + signing(&integration).await; + s3_encryption(&integration).await; + put_get_attributes(&integration).await; + + // Object tagging is not supported by S3 Express One Zone + if config.session_provider.is_none() { + tagging( + Arc::new(AmazonS3 { + client: Arc::clone(&integration.client), + }), + !config.disable_tagging, + |p| { + let client = Arc::clone(&integration.client); + async move { client.get_object_tagging(&p).await } + }, + ) + .await; + } + + if test_not_exists { + copy_if_not_exists(&integration).await; + } + if test_conditional_put { + put_opts(&integration, true).await; + } + + // run integration test with unsigned payload enabled + let builder = AmazonS3Builder::from_env().with_unsigned_payload(true); + let integration = builder.build().unwrap(); + put_get_delete_list(&integration).await; + + // run integration test with checksum set to sha256 + let builder = AmazonS3Builder::from_env().with_checksum_algorithm(Checksum::SHA256); + let integration = builder.build().unwrap(); + put_get_delete_list(&integration).await; + + match &integration.client.config.copy_if_not_exists { + Some(S3CopyIfNotExists::Dynamo(d)) => dynamo::integration_test(&integration, d).await, + _ => eprintln!("Skipping dynamo integration test - dynamo not configured"), + }; + } + + #[tokio::test] + async fn s3_test_get_nonexistent_location() { + maybe_skip_integration!(); + let integration = AmazonS3Builder::from_env().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + } + + #[tokio::test] + async fn s3_test_get_nonexistent_bucket() { + maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_put_nonexistent_bucket() { + maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + let data = PutPayload::from("arbitrary data"); + + let err = integration.put(&location, data).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_location() { + maybe_skip_integration!(); + let integration = AmazonS3Builder::from_env().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + integration.delete(&location).await.unwrap(); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_bucket() { + maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + #[ignore = "Tests shouldn't call use remote services by default"] + async fn test_disable_creds() { + // https://registry.opendata.aws/daylight-osm/ + let v1 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_access_key_id("local") + .with_secret_access_key("development") + .build() + .unwrap(); + + let prefix = Path::from("release"); + + v1.list_with_delimiter(Some(&prefix)).await.unwrap_err(); + + let v2 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_skip_signature(true) + .build() + .unwrap(); + + v2.list_with_delimiter(Some(&prefix)).await.unwrap(); + } + + async fn s3_encryption(store: &AmazonS3) { + maybe_skip_integration!(); + + let data = PutPayload::from(vec![3u8; 1024]); + + let encryption_headers: HeaderMap = store.client.config.encryption_headers.clone().into(); + let expected_encryption = + if let Some(encryption_type) = encryption_headers.get("x-amz-server-side-encryption") { + encryption_type + } else { + eprintln!("Skipping S3 encryption test - encryption not configured"); + return; + }; + + let locations = [ + Path::from("test-encryption-1"), + Path::from("test-encryption-2"), + Path::from("test-encryption-3"), + ]; + + store.put(&locations[0], data.clone()).await.unwrap(); + store.copy(&locations[0], &locations[1]).await.unwrap(); + + let mut upload = store.put_multipart(&locations[2]).await.unwrap(); + upload.put_part(data.clone()).await.unwrap(); + upload.complete().await.unwrap(); + + for location in &locations { + let res = store + .client + .get_request(location, GetOptions::default()) + .await + .unwrap(); + let headers = res.headers(); + assert_eq!( + headers + .get("x-amz-server-side-encryption") + .expect("object is not encrypted"), + expected_encryption + ); + + store.delete(location).await.unwrap(); + } + } + + /// See CONTRIBUTING.md for the MinIO setup for this 
test. + #[tokio::test] + async fn test_s3_ssec_encryption_with_minio() { + if std::env::var("TEST_S3_SSEC_ENCRYPTION").is_err() { + eprintln!("Skipping S3 SSE-C encryption test"); + return; + } + eprintln!("Running S3 SSE-C encryption test"); + + let customer_key = "1234567890abcdef1234567890abcdef"; + let expected_md5 = "JMwgiexXqwuPqIPjYFmIZQ=="; + + let store = AmazonS3Builder::from_env() + .with_ssec_encryption(BASE64_STANDARD.encode(customer_key)) + .with_client_options(ClientOptions::default().with_allow_invalid_certificates(true)) + .build() + .unwrap(); + + let data = PutPayload::from(vec![3u8; 1024]); + + let locations = [ + Path::from("test-encryption-1"), + Path::from("test-encryption-2"), + Path::from("test-encryption-3"), + ]; + + // Test put with sse-c. + store.put(&locations[0], data.clone()).await.unwrap(); + + // Test copy with sse-c. + store.copy(&locations[0], &locations[1]).await.unwrap(); + + // Test multipart upload with sse-c. + let mut upload = store.put_multipart(&locations[2]).await.unwrap(); + upload.put_part(data.clone()).await.unwrap(); + upload.complete().await.unwrap(); + + // Test get with sse-c. + for location in &locations { + let res = store + .client + .get_request(location, GetOptions::default()) + .await + .unwrap(); + let headers = res.headers(); + assert_eq!( + headers + .get("x-amz-server-side-encryption-customer-algorithm") + .expect("object is not encrypted with SSE-C"), + "AES256" + ); + + assert_eq!( + headers + .get("x-amz-server-side-encryption-customer-key-MD5") + .expect("object is not encrypted with SSE-C"), + expected_md5 + ); + + store.delete(location).await.unwrap(); + } + } +} diff --git a/src/aws/precondition.rs b/src/aws/precondition.rs new file mode 100644 index 0000000..ab5aea9 --- /dev/null +++ b/src/aws/precondition.rs @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aws::dynamo::DynamoCommit; +use crate::config::Parse; + +use itertools::Itertools; + +/// Configure how to provide [`ObjectStore::copy_if_not_exists`] for [`AmazonS3`]. +/// +/// [`ObjectStore::copy_if_not_exists`]: crate::ObjectStore::copy_if_not_exists +/// [`AmazonS3`]: super::AmazonS3 +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum S3CopyIfNotExists { + /// Some S3-compatible stores, such as Cloudflare R2, support copy if not exists + /// semantics through custom headers. + /// + /// If set, [`ObjectStore::copy_if_not_exists`] will perform a normal copy operation + /// with the provided header pair, and expect the store to fail with `412 Precondition Failed` + /// if the destination file already exists. 
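As a quick illustration of how this variant round-trips through the string encoding handled further down (the header name is the Cloudflare R2 one exercised by the tests):

```rust
use object_store::aws::S3CopyIfNotExists;

fn main() {
    // Matches the `Display` implementation and `from_str` parsing below.
    let r2 = S3CopyIfNotExists::Header(
        "cf-copy-destination-if-none-match".to_string(),
        "*".to_string(),
    );
    assert_eq!(r2.to_string(), "header: cf-copy-destination-if-none-match: *");
}
```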
+ /// + /// Encoded as `header::` ignoring whitespace + /// + /// For example `header: cf-copy-destination-if-none-match: *`, would set + /// the header `cf-copy-destination-if-none-match` to `*` + /// + /// [`ObjectStore::copy_if_not_exists`]: crate::ObjectStore::copy_if_not_exists + Header(String, String), + /// The same as [`S3CopyIfNotExists::Header`] but allows custom status code checking, for object stores that return values + /// other than 412. + /// + /// Encoded as `header-with-status:::` ignoring whitespace + HeaderWithStatus(String, String, reqwest::StatusCode), + /// Native Amazon S3 supports copy if not exists through a multipart upload + /// where the upload copies an existing object and is completed only if the + /// new object does not already exist. + /// + /// WARNING: When using this mode, `copy_if_not_exists` does not copy tags + /// or attributes from the source object. + /// + /// WARNING: When using this mode, `copy_if_not_exists` makes only a best + /// effort attempt to clean up the multipart upload if the copy operation + /// fails. Consider using a lifecycle rule to automatically clean up + /// abandoned multipart uploads. See [the module + /// docs](super#multipart-uploads) for details. + /// + /// Encoded as `multipart` ignoring whitespace. + Multipart, + /// The name of a DynamoDB table to use for coordination + /// + /// Encoded as either `dynamo:` or `dynamo::` + /// ignoring whitespace. The default timeout is used if not specified + /// + /// See [`DynamoCommit`] for more information + /// + /// This will use the same region, credentials and endpoint as configured for S3 + Dynamo(DynamoCommit), +} + +impl std::fmt::Display for S3CopyIfNotExists { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Header(k, v) => write!(f, "header: {}: {}", k, v), + Self::HeaderWithStatus(k, v, code) => { + write!(f, "header-with-status: {k}: {v}: {}", code.as_u16()) + } + Self::Multipart => f.write_str("multipart"), + Self::Dynamo(lock) => write!(f, "dynamo: {}", lock.table_name()), + } + } +} + +impl S3CopyIfNotExists { + fn from_str(s: &str) -> Option { + if s.trim() == "multipart" { + return Some(Self::Multipart); + }; + + let (variant, value) = s.split_once(':')?; + match variant.trim() { + "header" => { + let (k, v) = value.split_once(':')?; + Some(Self::Header(k.trim().to_string(), v.trim().to_string())) + } + "header-with-status" => { + let (k, v, status) = value.split(':').collect_tuple()?; + + let code = status.trim().parse().ok()?; + + Some(Self::HeaderWithStatus( + k.trim().to_string(), + v.trim().to_string(), + code, + )) + } + "dynamo" => Some(Self::Dynamo(DynamoCommit::from_str(value)?)), + _ => None, + } + } +} + +impl Parse for S3CopyIfNotExists { + fn parse(v: &str) -> crate::Result { + Self::from_str(v).ok_or_else(|| crate::Error::Generic { + store: "Config", + source: format!("Failed to parse \"{v}\" as S3CopyIfNotExists").into(), + }) + } +} + +/// Configure how to provide conditional put support for [`AmazonS3`]. 
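A minimal sketch of wiring these options into a store follows. The `with_copy_if_not_exists` and `with_conditional_put` setters are assumed from the broader `AmazonS3Builder` API rather than shown in this diff, and the bucket and table names are placeholders.

```rust
use object_store::aws::{AmazonS3Builder, DynamoCommit, S3ConditionalPut, S3CopyIfNotExists};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumes the builder exposes typed setters for these options; bucket and
    // lock-table names are illustrative.
    let store = AmazonS3Builder::from_env()
        .with_bucket_name("my-bucket")
        .with_copy_if_not_exists(S3CopyIfNotExists::Dynamo(DynamoCommit::new("locks".into())))
        .with_conditional_put(S3ConditionalPut::ETagMatch)
        .build()?;
    let _ = store;
    Ok(())
}
```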
+/// +/// [`AmazonS3`]: super::AmazonS3 +#[derive(Debug, Clone, Eq, PartialEq, Default)] +#[allow(missing_copy_implementations)] +#[non_exhaustive] +pub enum S3ConditionalPut { + /// Some S3-compatible stores, such as Cloudflare R2 and minio support conditional + /// put using the standard [HTTP precondition] headers If-Match and If-None-Match + /// + /// Encoded as `etag` ignoring whitespace + /// + /// [HTTP precondition]: https://datatracker.ietf.org/doc/html/rfc9110#name-preconditions + #[default] + ETagMatch, + + /// The name of a DynamoDB table to use for coordination + /// + /// Encoded as either `dynamo:` or `dynamo::` + /// ignoring whitespace. The default timeout is used if not specified + /// + /// See [`DynamoCommit`] for more information + /// + /// This will use the same region, credentials and endpoint as configured for S3 + Dynamo(DynamoCommit), + + /// Disable `conditional put` + Disabled, +} + +impl std::fmt::Display for S3ConditionalPut { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ETagMatch => write!(f, "etag"), + Self::Dynamo(lock) => write!(f, "dynamo: {}", lock.table_name()), + Self::Disabled => write!(f, "disabled"), + } + } +} + +impl S3ConditionalPut { + fn from_str(s: &str) -> Option { + match s.trim() { + "etag" => Some(Self::ETagMatch), + "disabled" => Some(Self::Disabled), + trimmed => match trimmed.split_once(':')? { + ("dynamo", s) => Some(Self::Dynamo(DynamoCommit::from_str(s)?)), + _ => None, + }, + } + } +} + +impl Parse for S3ConditionalPut { + fn parse(v: &str) -> crate::Result { + Self::from_str(v).ok_or_else(|| crate::Error::Generic { + store: "Config", + source: format!("Failed to parse \"{v}\" as S3PutConditional").into(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::S3CopyIfNotExists; + use crate::aws::{DynamoCommit, S3ConditionalPut}; + + #[test] + fn parse_s3_copy_if_not_exists_header() { + let input = "header: cf-copy-destination-if-none-match: *"; + let expected = Some(S3CopyIfNotExists::Header( + "cf-copy-destination-if-none-match".to_owned(), + "*".to_owned(), + )); + + assert_eq!(expected, S3CopyIfNotExists::from_str(input)); + } + + #[test] + fn parse_s3_copy_if_not_exists_header_with_status() { + let input = "header-with-status:key:value:403"; + let expected = Some(S3CopyIfNotExists::HeaderWithStatus( + "key".to_owned(), + "value".to_owned(), + reqwest::StatusCode::FORBIDDEN, + )); + + assert_eq!(expected, S3CopyIfNotExists::from_str(input)); + } + + #[test] + fn parse_s3_copy_if_not_exists_dynamo() { + let input = "dynamo: table:100"; + let expected = Some(S3CopyIfNotExists::Dynamo( + DynamoCommit::new("table".into()).with_timeout(100), + )); + assert_eq!(expected, S3CopyIfNotExists::from_str(input)); + } + + #[test] + fn parse_s3_condition_put_dynamo() { + let input = "dynamo: table:1300"; + let expected = Some(S3ConditionalPut::Dynamo( + DynamoCommit::new("table".into()).with_timeout(1300), + )); + assert_eq!(expected, S3ConditionalPut::from_str(input)); + } + + #[test] + fn parse_s3_copy_if_not_exists_header_whitespace_invariant() { + let expected = Some(S3CopyIfNotExists::Header( + "cf-copy-destination-if-none-match".to_owned(), + "*".to_owned(), + )); + + const INPUTS: &[&str] = &[ + "header:cf-copy-destination-if-none-match:*", + "header: cf-copy-destination-if-none-match:*", + "header: cf-copy-destination-if-none-match: *", + "header : cf-copy-destination-if-none-match: *", + "header : cf-copy-destination-if-none-match : *", + "header : cf-copy-destination-if-none-match : * 
", + ]; + + for input in INPUTS { + assert_eq!(expected, S3CopyIfNotExists::from_str(input)); + } + } + + #[test] + fn parse_s3_copy_if_not_exists_header_with_status_whitespace_invariant() { + let expected = Some(S3CopyIfNotExists::HeaderWithStatus( + "key".to_owned(), + "value".to_owned(), + reqwest::StatusCode::FORBIDDEN, + )); + + const INPUTS: &[&str] = &[ + "header-with-status:key:value:403", + "header-with-status: key:value:403", + "header-with-status: key: value:403", + "header-with-status: key: value: 403", + "header-with-status : key: value: 403", + "header-with-status : key : value: 403", + "header-with-status : key : value : 403", + "header-with-status : key : value : 403 ", + ]; + + for input in INPUTS { + assert_eq!(expected, S3CopyIfNotExists::from_str(input)); + } + } +} diff --git a/src/aws/resolve.rs b/src/aws/resolve.rs new file mode 100644 index 0000000..db899ea --- /dev/null +++ b/src/aws/resolve.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aws::STORE; +use crate::{ClientOptions, Result}; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Bucket '{}' not found", bucket)] + BucketNotFound { bucket: String }, + + #[error("Failed to resolve region for bucket '{}'", bucket)] + ResolveRegion { + bucket: String, + source: reqwest::Error, + }, + + #[error("Failed to parse the region for bucket '{}'", bucket)] + RegionParse { bucket: String }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(source), + } + } +} + +/// Get the bucket region using the [HeadBucket API]. This will fail if the bucket does not exist. 
+/// +/// [HeadBucket API]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadBucket.html +pub async fn resolve_bucket_region(bucket: &str, client_options: &ClientOptions) -> Result { + use reqwest::StatusCode; + + let endpoint = format!("https://{}.s3.amazonaws.com", bucket); + + let client = client_options.client()?; + + let response = client.head(&endpoint).send().await.map_err(|source| { + let bucket = bucket.into(); + Error::ResolveRegion { bucket, source } + })?; + + if response.status() == StatusCode::NOT_FOUND { + let bucket = bucket.into(); + return Err(Error::BucketNotFound { bucket }.into()); + } + + let region = response + .headers() + .get("x-amz-bucket-region") + .and_then(|x| x.to_str().ok()) + .ok_or_else(|| Error::RegionParse { + bucket: bucket.into(), + })?; + + Ok(region.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_bucket_does_not_exist() { + let bucket = "please-dont-exist"; + + let result = resolve_bucket_region(bucket, &ClientOptions::new()).await; + + assert!(result.is_err()); + } +} diff --git a/src/azure/builder.rs b/src/azure/builder.rs new file mode 100644 index 0000000..f176fc6 --- /dev/null +++ b/src/azure/builder.rs @@ -0,0 +1,1242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::azure::client::{AzureClient, AzureConfig}; +use crate::azure::credential::{ + AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider, + ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider, +}; +use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE}; +use crate::client::{http_connector, HttpConnector, TokenCredentialProvider}; +use crate::config::ConfigValue; +use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider}; +use percent_encoding::percent_decode_str; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use std::sync::Arc; +use url::Url; + +/// The well-known account used by Azurite and the legacy Azure Storage Emulator. +/// +/// +const EMULATOR_ACCOUNT: &str = "devstoreaccount1"; + +/// The well-known account key used by Azurite and the legacy Azure Storage Emulator. +/// +/// +const EMULATOR_ACCOUNT_KEY: &str = + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; + +const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT"; + +/// A specialized `Error` for Azure builder-related errors +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Unable parse source url. 
Url: {}, Error: {}", url, source)] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[error( + "Unable parse emulator url {}={}, Error: {}", + env_name, + env_value, + source + )] + UnableToParseEmulatorUrl { + env_name: String, + env_value: String, + source: url::ParseError, + }, + + #[error("Account must be specified")] + MissingAccount {}, + + #[error("Container name must be specified")] + MissingContainerName {}, + + #[error( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + )] + UnknownUrlScheme { scheme: String }, + + #[error("URL did not match any known pattern for scheme: {}", url)] + UrlNotRecognised { url: String }, + + #[error("Failed parsing an SAS key")] + DecodeSasKey { source: std::str::Utf8Error }, + + #[error("Missing component in SAS query pair")] + MissingSasComponent {}, + + #[error("Configuration key: '{}' is not known.", key)] + UnknownConfigurationKey { key: String }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + match source { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(source), + }, + } + } +} + +/// Configure a connection to Microsoft Azure Blob Storage container using +/// the specified credentials. +/// +/// # Example +/// ``` +/// # let ACCOUNT = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY = "foo"; +/// # use object_store::azure::MicrosoftAzureBuilder; +/// let azure = MicrosoftAzureBuilder::new() +/// .with_account(ACCOUNT) +/// .with_access_key(ACCESS_KEY) +/// .with_container_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Default, Clone)] +pub struct MicrosoftAzureBuilder { + /// Account name + account_name: Option, + /// Access key + access_key: Option, + /// Container name + container_name: Option, + /// Bearer token + bearer_token: Option, + /// Client id + client_id: Option, + /// Client secret + client_secret: Option, + /// Tenant id + tenant_id: Option, + /// Query pairs for shared access signature authorization + sas_query_pairs: Option>, + /// Shared access signature + sas_key: Option, + /// Authority host + authority_host: Option, + /// Url + url: Option, + /// When set to true, azurite storage emulator has to be used + use_emulator: ConfigValue, + /// Storage endpoint + endpoint: Option, + /// Msi endpoint for acquiring managed identity token + msi_endpoint: Option, + /// Object id for use with managed identity authentication + object_id: Option, + /// Msi resource id for use with managed identity authentication + msi_resource_id: Option, + /// File containing token for Azure AD workload identity federation + federated_token_file: Option, + /// When set to true, azure cli has to be used for acquiring access token + use_azure_cli: ConfigValue, + /// Retry config + retry_config: RetryConfig, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, + /// When set to true, fabric url scheme will be used + /// + /// i.e. 
https://{account_name}.dfs.fabric.microsoft.com + use_fabric_endpoint: ConfigValue, + /// When set to true, skips tagging objects + disable_tagging: ConfigValue, + /// Fabric token service url + fabric_token_service_url: Option, + /// Fabric workload host + fabric_workload_host: Option, + /// Fabric session token + fabric_session_token: Option, + /// Fabric cluster identifier + fabric_cluster_identifier: Option, + /// The [`HttpConnector`] to use + http_connector: Option>, +} + +/// Configuration keys for [`MicrosoftAzureBuilder`] +/// +/// Configuration via keys can be done via [`MicrosoftAzureBuilder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; +/// let builder = MicrosoftAzureBuilder::new() +/// .with_config("azure_client_id".parse().unwrap(), "my-client-id") +/// .with_config(AzureConfigKey::AuthorityId, "my-tenant-id"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)] +#[non_exhaustive] +pub enum AzureConfigKey { + /// The name of the azure storage account + /// + /// Supported keys: + /// - `azure_storage_account_name` + /// - `account_name` + AccountName, + + /// Master key for accessing storage account + /// + /// Supported keys: + /// - `azure_storage_account_key` + /// - `azure_storage_access_key` + /// - `azure_storage_master_key` + /// - `access_key` + /// - `account_key` + /// - `master_key` + AccessKey, + + /// Service principal client id for authorizing requests + /// + /// Supported keys: + /// - `azure_storage_client_id` + /// - `azure_client_id` + /// - `client_id` + ClientId, + + /// Service principal client secret for authorizing requests + /// + /// Supported keys: + /// - `azure_storage_client_secret` + /// - `azure_client_secret` + /// - `client_secret` + ClientSecret, + + /// Tenant id used in oauth flows + /// + /// Supported keys: + /// - `azure_storage_tenant_id` + /// - `azure_storage_authority_id` + /// - `azure_tenant_id` + /// - `azure_authority_id` + /// - `tenant_id` + /// - `authority_id` + AuthorityId, + + /// Authority host used in oauth flows + /// + /// Supported keys: + /// - `azure_storage_authority_host` + /// - `azure_authority_host` + /// - `authority_host` + AuthorityHost, + + /// Shared access signature. + /// + /// The signature is expected to be percent-encoded, much like they are provided + /// in the azure storage explorer or azure portal. 
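A small sketch of supplying such a token through the builder; the account, container and token values below are placeholders.

```rust
use object_store::azure::{AzureConfigKey, MicrosoftAzureBuilder};

fn main() {
    // The SAS token here is a stand-in; a real one is generated in the Azure
    // portal or storage explorer and arrives already percent-encoded.
    let sas = "sv=2021-08-06&ss=b&srt=co&sp=rl&sig=REDACTED";
    let builder = MicrosoftAzureBuilder::new()
        .with_account("myaccount")
        .with_container_name("mycontainer")
        .with_config(AzureConfigKey::SasKey, sas);
    let _ = builder;
}
```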
+ /// + /// Supported keys: + /// - `azure_storage_sas_key` + /// - `azure_storage_sas_token` + /// - `sas_key` + /// - `sas_token` + SasKey, + + /// Bearer token + /// + /// Supported keys: + /// - `azure_storage_token` + /// - `bearer_token` + /// - `token` + Token, + + /// Use object store with azurite storage emulator + /// + /// Supported keys: + /// - `azure_storage_use_emulator` + /// - `object_store_use_emulator` + /// - `use_emulator` + UseEmulator, + + /// Override the endpoint used to communicate with blob storage + /// + /// Supported keys: + /// - `azure_storage_endpoint` + /// - `azure_endpoint` + /// - `endpoint` + Endpoint, + + /// Use object store with url scheme account.dfs.fabric.microsoft.com + /// + /// Supported keys: + /// - `azure_use_fabric_endpoint` + /// - `use_fabric_endpoint` + UseFabricEndpoint, + + /// Endpoint to request a imds managed identity token + /// + /// Supported keys: + /// - `azure_msi_endpoint` + /// - `azure_identity_endpoint` + /// - `identity_endpoint` + /// - `msi_endpoint` + MsiEndpoint, + + /// Object id for use with managed identity authentication + /// + /// Supported keys: + /// - `azure_object_id` + /// - `object_id` + ObjectId, + + /// Msi resource id for use with managed identity authentication + /// + /// Supported keys: + /// - `azure_msi_resource_id` + /// - `msi_resource_id` + MsiResourceId, + + /// File containing token for Azure AD workload identity federation + /// + /// Supported keys: + /// - `azure_federated_token_file` + /// - `federated_token_file` + FederatedTokenFile, + + /// Use azure cli for acquiring access token + /// + /// Supported keys: + /// - `azure_use_azure_cli` + /// - `use_azure_cli` + UseAzureCli, + + /// Skip signing requests + /// + /// Supported keys: + /// - `azure_skip_signature` + /// - `skip_signature` + SkipSignature, + + /// Container name + /// + /// Supported keys: + /// - `azure_container_name` + /// - `container_name` + ContainerName, + + /// Disables tagging objects + /// + /// This can be desirable if not supported by the backing store + /// + /// Supported keys: + /// - `azure_disable_tagging` + /// - `disable_tagging` + DisableTagging, + + /// Fabric token service url + /// + /// Supported keys: + /// - `azure_fabric_token_service_url` + /// - `fabric_token_service_url` + FabricTokenServiceUrl, + + /// Fabric workload host + /// + /// Supported keys: + /// - `azure_fabric_workload_host` + /// - `fabric_workload_host` + FabricWorkloadHost, + + /// Fabric session token + /// + /// Supported keys: + /// - `azure_fabric_session_token` + /// - `fabric_session_token` + FabricSessionToken, + + /// Fabric cluster identifier + /// + /// Supported keys: + /// - `azure_fabric_cluster_identifier` + /// - `fabric_cluster_identifier` + FabricClusterIdentifier, + + /// Client options + Client(ClientConfigKey), +} + +impl AsRef for AzureConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::AccountName => "azure_storage_account_name", + Self::AccessKey => "azure_storage_account_key", + Self::ClientId => "azure_storage_client_id", + Self::ClientSecret => "azure_storage_client_secret", + Self::AuthorityId => "azure_storage_tenant_id", + Self::AuthorityHost => "azure_storage_authority_host", + Self::SasKey => "azure_storage_sas_key", + Self::Token => "azure_storage_token", + Self::UseEmulator => "azure_storage_use_emulator", + Self::UseFabricEndpoint => "azure_use_fabric_endpoint", + Self::Endpoint => "azure_storage_endpoint", + Self::MsiEndpoint => "azure_msi_endpoint", + Self::ObjectId => 
"azure_object_id", + Self::MsiResourceId => "azure_msi_resource_id", + Self::FederatedTokenFile => "azure_federated_token_file", + Self::UseAzureCli => "azure_use_azure_cli", + Self::SkipSignature => "azure_skip_signature", + Self::ContainerName => "azure_container_name", + Self::DisableTagging => "azure_disable_tagging", + Self::FabricTokenServiceUrl => "azure_fabric_token_service_url", + Self::FabricWorkloadHost => "azure_fabric_workload_host", + Self::FabricSessionToken => "azure_fabric_session_token", + Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier", + Self::Client(key) => key.as_ref(), + } + } +} + +impl FromStr for AzureConfigKey { + type Err = crate::Error; + + fn from_str(s: &str) -> Result { + match s { + "azure_storage_account_key" + | "azure_storage_access_key" + | "azure_storage_master_key" + | "master_key" + | "account_key" + | "access_key" => Ok(Self::AccessKey), + "azure_storage_account_name" | "account_name" => Ok(Self::AccountName), + "azure_storage_client_id" | "azure_client_id" | "client_id" => Ok(Self::ClientId), + "azure_storage_client_secret" | "azure_client_secret" | "client_secret" => { + Ok(Self::ClientSecret) + } + "azure_storage_tenant_id" + | "azure_storage_authority_id" + | "azure_tenant_id" + | "azure_authority_id" + | "tenant_id" + | "authority_id" => Ok(Self::AuthorityId), + "azure_storage_authority_host" | "azure_authority_host" | "authority_host" => { + Ok(Self::AuthorityHost) + } + "azure_storage_sas_key" | "azure_storage_sas_token" | "sas_key" | "sas_token" => { + Ok(Self::SasKey) + } + "azure_storage_token" | "bearer_token" | "token" => Ok(Self::Token), + "azure_storage_use_emulator" | "use_emulator" => Ok(Self::UseEmulator), + "azure_storage_endpoint" | "azure_endpoint" | "endpoint" => Ok(Self::Endpoint), + "azure_msi_endpoint" + | "azure_identity_endpoint" + | "identity_endpoint" + | "msi_endpoint" => Ok(Self::MsiEndpoint), + "azure_object_id" | "object_id" => Ok(Self::ObjectId), + "azure_msi_resource_id" | "msi_resource_id" => Ok(Self::MsiResourceId), + "azure_federated_token_file" | "federated_token_file" => Ok(Self::FederatedTokenFile), + "azure_use_fabric_endpoint" | "use_fabric_endpoint" => Ok(Self::UseFabricEndpoint), + "azure_use_azure_cli" | "use_azure_cli" => Ok(Self::UseAzureCli), + "azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), + "azure_container_name" | "container_name" => Ok(Self::ContainerName), + "azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging), + "azure_fabric_token_service_url" | "fabric_token_service_url" => { + Ok(Self::FabricTokenServiceUrl) + } + "azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost), + "azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken), + "azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => { + Ok(Self::FabricClusterIdentifier) + } + // Backwards compatibility + "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), + _ => match s.strip_prefix("azure_").unwrap_or(s).parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl std::fmt::Debug for MicrosoftAzureBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MicrosoftAzureBuilder {{ account: {:?}, container_name: {:?} }}", + self.account_name, self.container_name + ) + } +} + +impl MicrosoftAzureBuilder { + /// Create a new [`MicrosoftAzureBuilder`] with default values. 
+ pub fn new() -> Self { + Default::default() + } + + /// Create an instance of [`MicrosoftAzureBuilder`] with values pre-populated from environment variables. + /// + /// Variables extracted from environment: + /// * AZURE_STORAGE_ACCOUNT_NAME: storage account name + /// * AZURE_STORAGE_ACCOUNT_KEY: storage account master key + /// * AZURE_STORAGE_ACCESS_KEY: alias for AZURE_STORAGE_ACCOUNT_KEY + /// * AZURE_STORAGE_CLIENT_ID -> client id for service principal authorization + /// * AZURE_STORAGE_CLIENT_SECRET -> client secret for service principal authorization + /// * AZURE_STORAGE_TENANT_ID -> tenant id used in oauth flows + /// # Example + /// ``` + /// use object_store::azure::MicrosoftAzureBuilder; + /// + /// let azure = MicrosoftAzureBuilder::from_env() + /// .with_container_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder = Self::default(); + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("AZURE_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) { + builder = builder.with_msi_endpoint(text); + } + + builder + } + + /// Parse available connection info form a well-known storage URL. + /// + /// The supported url schemes are: + /// + /// - `abfs[s]:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `abfs[s]://@.dfs.core.windows.net/` + /// - `abfs[s]://@.dfs.fabric.microsoft.com/` + /// - `az:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `adl:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `azure:///` (custom) + /// - `https://.dfs.core.windows.net` + /// - `https://.blob.core.windows.net` + /// - `https://.blob.core.windows.net/` + /// - `https://.dfs.fabric.microsoft.com` + /// - `https://.dfs.fabric.microsoft.com/` + /// - `https://.blob.fabric.microsoft.com` + /// - `https://.blob.fabric.microsoft.com/` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::azure::MicrosoftAzureBuilder; + /// + /// let azure = MicrosoftAzureBuilder::from_env() + /// .with_url("abfss://file_system@account.dfs.core.windows.net/") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair. 
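As a small illustration of the alias handling in the `FromStr` implementation above (the unknown key is deliberately made up); keys parsed this way are what `with_config` below accepts:

```rust
use object_store::azure::AzureConfigKey;

fn main() {
    // Long and short aliases resolve to the same configuration key.
    let long: AzureConfigKey = "azure_storage_account_name".parse().unwrap();
    let short: AzureConfigKey = "account_name".parse().unwrap();
    assert_eq!(long, AzureConfigKey::AccountName);
    assert_eq!(long, short);

    // Unrecognised keys are rejected rather than silently ignored.
    assert!("azure_definitely_not_a_key".parse::<AzureConfigKey>().is_err());
}
```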
+ pub fn with_config(mut self, key: AzureConfigKey, value: impl Into) -> Self { + match key { + AzureConfigKey::AccessKey => self.access_key = Some(value.into()), + AzureConfigKey::AccountName => self.account_name = Some(value.into()), + AzureConfigKey::ClientId => self.client_id = Some(value.into()), + AzureConfigKey::ClientSecret => self.client_secret = Some(value.into()), + AzureConfigKey::AuthorityId => self.tenant_id = Some(value.into()), + AzureConfigKey::AuthorityHost => self.authority_host = Some(value.into()), + AzureConfigKey::SasKey => self.sas_key = Some(value.into()), + AzureConfigKey::Token => self.bearer_token = Some(value.into()), + AzureConfigKey::MsiEndpoint => self.msi_endpoint = Some(value.into()), + AzureConfigKey::ObjectId => self.object_id = Some(value.into()), + AzureConfigKey::MsiResourceId => self.msi_resource_id = Some(value.into()), + AzureConfigKey::FederatedTokenFile => self.federated_token_file = Some(value.into()), + AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value), + AzureConfigKey::SkipSignature => self.skip_signature.parse(value), + AzureConfigKey::UseEmulator => self.use_emulator.parse(value), + AzureConfigKey::Endpoint => self.endpoint = Some(value.into()), + AzureConfigKey::UseFabricEndpoint => self.use_fabric_endpoint.parse(value), + AzureConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + AzureConfigKey::ContainerName => self.container_name = Some(value.into()), + AzureConfigKey::DisableTagging => self.disable_tagging.parse(value), + AzureConfigKey::FabricTokenServiceUrl => { + self.fabric_token_service_url = Some(value.into()) + } + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()), + AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()), + AzureConfigKey::FabricClusterIdentifier => { + self.fabric_cluster_identifier = Some(value.into()) + } + }; + self + } + + /// Get config value via a [`AzureConfigKey`]. 
+ /// + /// # Example + /// ``` + /// use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; + /// + /// let builder = MicrosoftAzureBuilder::from_env() + /// .with_account("foo"); + /// let account_name = builder.get_config_value(&AzureConfigKey::AccountName).unwrap_or_default(); + /// assert_eq!("foo", &account_name); + /// ``` + pub fn get_config_value(&self, key: &AzureConfigKey) -> Option { + match key { + AzureConfigKey::AccountName => self.account_name.clone(), + AzureConfigKey::AccessKey => self.access_key.clone(), + AzureConfigKey::ClientId => self.client_id.clone(), + AzureConfigKey::ClientSecret => self.client_secret.clone(), + AzureConfigKey::AuthorityId => self.tenant_id.clone(), + AzureConfigKey::AuthorityHost => self.authority_host.clone(), + AzureConfigKey::SasKey => self.sas_key.clone(), + AzureConfigKey::Token => self.bearer_token.clone(), + AzureConfigKey::UseEmulator => Some(self.use_emulator.to_string()), + AzureConfigKey::UseFabricEndpoint => Some(self.use_fabric_endpoint.to_string()), + AzureConfigKey::Endpoint => self.endpoint.clone(), + AzureConfigKey::MsiEndpoint => self.msi_endpoint.clone(), + AzureConfigKey::ObjectId => self.object_id.clone(), + AzureConfigKey::MsiResourceId => self.msi_resource_id.clone(), + AzureConfigKey::FederatedTokenFile => self.federated_token_file.clone(), + AzureConfigKey::UseAzureCli => Some(self.use_azure_cli.to_string()), + AzureConfigKey::SkipSignature => Some(self.skip_signature.to_string()), + AzureConfigKey::Client(key) => self.client_options.get_config_value(key), + AzureConfigKey::ContainerName => self.container_name.clone(), + AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()), + AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(), + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(), + AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(), + AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(), + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).map_err(|source| { + let url = url.into(); + Error::UnableToParseUrl { url, source } + })?; + + let host = parsed + .host_str() + .ok_or_else(|| Error::UrlNotRecognised { url: url.into() })?; + + let validate = |s: &str| match s.contains('.') { + true => Err(Error::UrlNotRecognised { url: url.into() }), + false => Ok(s.to_string()), + }; + + match parsed.scheme() { + "az" | "adl" | "azure" => self.container_name = Some(validate(host)?), + "abfs" | "abfss" => { + // abfs(s) might refer to the fsspec convention abfs:/// + // or the convention for the hadoop driver abfs[s]://@.dfs.core.windows.net/ + if parsed.username().is_empty() { + self.container_name = Some(validate(host)?); + } else if let Some(a) = host.strip_suffix(".dfs.core.windows.net") { + self.container_name = Some(validate(parsed.username())?); + self.account_name = Some(validate(a)?); + } else if let Some(a) = host.strip_suffix(".dfs.fabric.microsoft.com") { + self.container_name = Some(validate(parsed.username())?); + self.account_name = Some(validate(a)?); + self.use_fabric_endpoint = true.into(); + } else { + return Err(Error::UrlNotRecognised { url: url.into() }.into()); + } + } + "https" => match host.split_once('.') { + Some((a, 
"dfs.core.windows.net")) | Some((a, "blob.core.windows.net")) => { + self.account_name = Some(validate(a)?); + if let Some(container) = parsed.path_segments().unwrap().next() { + self.container_name = Some(validate(container)?); + } + } + Some((a, "dfs.fabric.microsoft.com")) | Some((a, "blob.fabric.microsoft.com")) => { + self.account_name = Some(validate(a)?); + // Attempt to infer the container name from the URL + // - https://onelake.dfs.fabric.microsoft.com///Files/test.csv + // - https://onelake.dfs.fabric.microsoft.com//.// + // + // See + if let Some(workspace) = parsed.path_segments().unwrap().next() { + if !workspace.is_empty() { + self.container_name = Some(workspace.to_string()) + } + } + self.use_fabric_endpoint = true.into(); + } + _ => return Err(Error::UrlNotRecognised { url: url.into() }.into()), + }, + scheme => { + let scheme = scheme.into(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } + } + Ok(()) + } + + /// Set the Azure Account (required) + pub fn with_account(mut self, account: impl Into) -> Self { + self.account_name = Some(account.into()); + self + } + + /// Set the Azure Container Name (required) + pub fn with_container_name(mut self, container_name: impl Into) -> Self { + self.container_name = Some(container_name.into()); + self + } + + /// Set the Azure Access Key (required - one of access key, bearer token, or client credentials) + pub fn with_access_key(mut self, access_key: impl Into) -> Self { + self.access_key = Some(access_key.into()); + self + } + + /// Set a static bearer token to be used for authorizing requests + pub fn with_bearer_token_authorization(mut self, bearer_token: impl Into) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + /// Set a client secret used for client secret authorization + pub fn with_client_secret_authorization( + mut self, + client_id: impl Into, + client_secret: impl Into, + tenant_id: impl Into, + ) -> Self { + self.client_id = Some(client_id.into()); + self.client_secret = Some(client_secret.into()); + self.tenant_id = Some(tenant_id.into()); + self + } + + /// Sets the client id for use in client secret or k8s federated credential flow + pub fn with_client_id(mut self, client_id: impl Into) -> Self { + self.client_id = Some(client_id.into()); + self + } + + /// Sets the client secret for use in client secret flow + pub fn with_client_secret(mut self, client_secret: impl Into) -> Self { + self.client_secret = Some(client_secret.into()); + self + } + + /// Sets the tenant id for use in client secret or k8s federated credential flow + pub fn with_tenant_id(mut self, tenant_id: impl Into) -> Self { + self.tenant_id = Some(tenant_id.into()); + self + } + + /// Set query pairs appended to the url for shared access signature authorization + pub fn with_sas_authorization(mut self, query_pairs: impl Into>) -> Self { + self.sas_query_pairs = Some(query_pairs.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Set if the Azure emulator should be used (defaults to false) + pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { + self.use_emulator = use_emulator.into(); + self + } + + /// Override the endpoint used to communicate with blob storage + /// + /// Defaults to `https://{account}.blob.core.windows.net` + /// + /// By default, only HTTPS schemes are enabled. 
To connect to an HTTP endpoint, enable + /// [`Self::with_allow_http`]. + pub fn with_endpoint(mut self, endpoint: String) -> Self { + self.endpoint = Some(endpoint); + self + } + + /// Set if Microsoft Fabric url scheme should be used (defaults to false) + /// + /// When disabled the url scheme used is `https://{account}.blob.core.windows.net` + /// When enabled the url scheme used is `https://{account}.dfs.fabric.microsoft.com` + /// + /// Note: [`Self::with_endpoint`] will take precedence over this option + pub fn with_use_fabric_endpoint(mut self, use_fabric_endpoint: bool) -> Self { + self.use_fabric_endpoint = use_fabric_endpoint.into(); + self + } + + /// Sets what protocol is allowed + /// + /// If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.client_options = self.client_options.with_allow_http(allow_http); + self + } + + /// Sets an alternative authority host for OAuth based authorization + /// + /// Common hosts for azure clouds are defined in [authority_hosts](crate::azure::authority_hosts). + /// + /// Defaults to + pub fn with_authority_host(mut self, authority_host: impl Into) -> Self { + self.authority_host = Some(authority_host.into()); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate(mut self, proxy_ca_certificate: impl Into) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// Sets the endpoint for acquiring managed identity token + pub fn with_msi_endpoint(mut self, msi_endpoint: impl Into) -> Self { + self.msi_endpoint = Some(msi_endpoint.into()); + self + } + + /// Sets a file path for acquiring azure federated identity token in k8s + /// + /// requires `client_id` and `tenant_id` to be set + pub fn with_federated_token_file(mut self, federated_token_file: impl Into) -> Self { + self.federated_token_file = Some(federated_token_file.into()); + self + } + + /// Set if the Azure Cli should be used for acquiring access token + /// + /// + pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self { + self.use_azure_cli = use_azure_cli.into(); + self + } + + /// If enabled, [`MicrosoftAzure`] will not fetch credentials and will not sign requests + /// + /// This can be useful when interacting with public containers + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + + /// If set to `true` will ignore any tags provided to put_opts + pub fn with_disable_tagging(mut self, ignore: bool) -> Self { + self.disable_tagging = ignore.into(); + self + } + + /// The [`HttpConnector`] to use + /// + /// On 
non-WASM32 platforms uses [`reqwest`] by default, on WASM32 platforms must be provided + pub fn with_http_connector(mut self, connector: C) -> Self { + self.http_connector = Some(Arc::new(connector)); + self + } + + /// Configure a connection to container with given name on Microsoft Azure Blob store. + pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let container = self.container_name.ok_or(Error::MissingContainerName {})?; + + let static_creds = |credential: AzureCredential| -> AzureCredentialProvider { + Arc::new(StaticCredentialProvider::new(credential)) + }; + + let http = http_connector(self.http_connector)?; + + let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? { + let account_name = self + .account_name + .unwrap_or_else(|| EMULATOR_ACCOUNT.to_string()); + // Allow overriding defaults. Values taken from + // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 + let url = url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; + let credential = if let Some(k) = self.access_key { + AzureCredential::AccessKey(AzureAccessKey::try_new(&k)?) + } else if let Some(bearer_token) = self.bearer_token { + AzureCredential::BearerToken(bearer_token) + } else if let Some(query_pairs) = self.sas_query_pairs { + AzureCredential::SASToken(query_pairs) + } else if let Some(sas) = self.sas_key { + AzureCredential::SASToken(split_sas(&sas)?) + } else { + AzureCredential::AccessKey(AzureAccessKey::try_new(EMULATOR_ACCOUNT_KEY)?) + }; + + self.client_options = self.client_options.with_allow_http(true); + (true, url, static_creds(credential), account_name) + } else { + let account_name = self.account_name.ok_or(Error::MissingAccount {})?; + let account_url = match self.endpoint { + Some(account_url) => account_url, + None => match self.use_fabric_endpoint.get()? { + true => { + format!("https://{}.blob.fabric.microsoft.com", &account_name) + } + false => format!("https://{}.blob.core.windows.net", &account_name), + }, + }; + + let url = Url::parse(&account_url).map_err(|source| { + let url = account_url.clone(); + Error::UnableToParseUrl { url, source } + })?; + + let credential = if let Some(credential) = self.credentials { + credential + } else if let ( + Some(fabric_token_service_url), + Some(fabric_workload_host), + Some(fabric_session_token), + Some(fabric_cluster_identifier), + ) = ( + &self.fabric_token_service_url, + &self.fabric_workload_host, + &self.fabric_session_token, + &self.fabric_cluster_identifier, + ) { + // This case should precede the bearer token case because it is more specific and will utilize the bearer token. 
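+ // For reference, the overall resolution order of this chain is: an
+ // explicitly provided credential provider, Fabric OAuth, bearer token,
+ // access key, workload identity (federated token file), client secret,
+ // SAS query pairs or SAS key, Azure CLI, and finally IMDS managed
+ // identity as the fallback.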
+ let fabric_credential = FabricTokenOAuthProvider::new( + fabric_token_service_url, + fabric_workload_host, + fabric_session_token, + fabric_cluster_identifier, + self.bearer_token.clone(), + ); + Arc::new(TokenCredentialProvider::new( + fabric_credential, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } else if let Some(bearer_token) = self.bearer_token { + static_creds(AzureCredential::BearerToken(bearer_token)) + } else if let Some(access_key) = self.access_key { + let key = AzureAccessKey::try_new(&access_key)?; + static_creds(AzureCredential::AccessKey(key)) + } else if let (Some(client_id), Some(tenant_id), Some(federated_token_file)) = + (&self.client_id, &self.tenant_id, self.federated_token_file) + { + let client_credential = WorkloadIdentityOAuthProvider::new( + client_id, + federated_token_file, + tenant_id, + self.authority_host, + ); + Arc::new(TokenCredentialProvider::new( + client_credential, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } else if let (Some(client_id), Some(client_secret), Some(tenant_id)) = + (&self.client_id, self.client_secret, &self.tenant_id) + { + let client_credential = ClientSecretOAuthProvider::new( + client_id.clone(), + client_secret, + tenant_id, + self.authority_host, + ); + Arc::new(TokenCredentialProvider::new( + client_credential, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } else if let Some(query_pairs) = self.sas_query_pairs { + static_creds(AzureCredential::SASToken(query_pairs)) + } else if let Some(sas) = self.sas_key { + static_creds(AzureCredential::SASToken(split_sas(&sas)?)) + } else if self.use_azure_cli.get()? { + Arc::new(AzureCliCredential::new()) as _ + } else { + let msi_credential = ImdsManagedIdentityProvider::new( + self.client_id, + self.object_id, + self.msi_resource_id, + self.msi_endpoint, + ); + Arc::new(TokenCredentialProvider::new( + msi_credential, + http.connect(&self.client_options.metadata_options())?, + self.retry_config.clone(), + )) as _ + }; + (false, url, credential, account_name) + }; + + let config = AzureConfig { + account, + is_emulator, + skip_signature: self.skip_signature.get()?, + container, + disable_tagging: self.disable_tagging.get()?, + retry_config: self.retry_config, + client_options: self.client_options, + service: storage_url, + credentials: auth, + }; + + let http_client = http.connect(&config.client_options)?; + let client = Arc::new(AzureClient::new(config, http_client)); + + Ok(MicrosoftAzure { client }) + } +} + +/// Parses the contents of the environment variable `env_name` as a URL +/// if present, otherwise falls back to default_url +fn url_from_env(env_name: &str, default_url: &str) -> Result { + let url = match std::env::var(env_name) { + Ok(env_value) => { + Url::parse(&env_value).map_err(|source| Error::UnableToParseEmulatorUrl { + env_name: env_name.into(), + env_value, + source, + })? 
+ } + Err(_) => Url::parse(default_url).expect("Failed to parse default URL"), + }; + Ok(url) +} + +fn split_sas(sas: &str) -> Result, Error> { + let sas = percent_decode_str(sas) + .decode_utf8() + .map_err(|source| Error::DecodeSasKey { source })?; + let kv_str_pairs = sas + .trim_start_matches('?') + .split('&') + .filter(|s| !s.chars().all(char::is_whitespace)); + let mut pairs = Vec::new(); + for kv_pair_str in kv_str_pairs { + let (k, v) = kv_pair_str + .trim() + .split_once('=') + .ok_or(Error::MissingSasComponent {})?; + pairs.push((k.into(), v.into())) + } + Ok(pairs) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn azure_blob_test_urls() { + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("abfss://file_system@account.dfs.core.windows.net/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("abfss://file_system@account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("abfs://container/path").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("az://container").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("az://container/path").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.core.windows.net/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.core.windows.net/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.core.windows.net/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("container".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + 
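+ // A fabric blob endpoint without a path segment leaves the container name unset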
assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let err_cases = [ + "mailto://account.blob.core.windows.net/", + "az://blob.mydomain/", + "abfs://container.foo/path", + "abfss://file_system@account.foo.dfs.core.windows.net/", + "abfss://file_system.bar@account.dfs.core.windows.net/", + "https://blob.mydomain/", + "https://blob.foo.dfs.core.windows.net/", + ]; + let mut builder = MicrosoftAzureBuilder::new(); + for case in err_cases { + builder.parse_url(case).unwrap_err(); + } + } + + #[test] + fn azure_test_config_from_map() { + let azure_client_id = "object_store:fake_access_key_id"; + let azure_storage_account_name = "object_store:fake_secret_key"; + let azure_storage_token = "object_store:fake_default_region"; + let options = HashMap::from([ + ("azure_client_id", azure_client_id), + ("azure_storage_account_name", azure_storage_account_name), + ("azure_storage_token", azure_storage_token), + ]); + + let builder = options + .into_iter() + .fold(MicrosoftAzureBuilder::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }); + assert_eq!(builder.client_id.unwrap(), azure_client_id); + assert_eq!(builder.account_name.unwrap(), azure_storage_account_name); + assert_eq!(builder.bearer_token.unwrap(), azure_storage_token); + } + + #[test] + fn azure_test_split_sas() { + let raw_sas = "?sv=2021-10-04&st=2023-01-04T17%3A48%3A57Z&se=2023-01-04T18%3A15%3A00Z&sr=c&sp=rcwl&sig=C7%2BZeEOWbrxPA3R0Cw%2Fw1EZz0%2B4KBvQexeKZKe%2BB6h0%3D"; + let expected = vec![ + ("sv".to_string(), "2021-10-04".to_string()), + ("st".to_string(), "2023-01-04T17:48:57Z".to_string()), + ("se".to_string(), "2023-01-04T18:15:00Z".to_string()), + ("sr".to_string(), "c".to_string()), + ("sp".to_string(), "rcwl".to_string()), + ( + "sig".to_string(), + "C7+ZeEOWbrxPA3R0Cw/w1EZz0+4KBvQexeKZKe+B6h0=".to_string(), + ), + ]; + let pairs = split_sas(raw_sas).unwrap(); + assert_eq!(expected, pairs); + } + + #[test] + fn azure_test_client_opts() { + let key = "AZURE_PROXY_URL"; + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + assert_eq!( + AzureConfigKey::Client(ClientConfigKey::ProxyUrl), + config_key + ); + } else { + panic!("{} not propagated as ClientConfigKey", key); + } + } +} diff --git a/src/azure/client.rs b/src/azure/client.rs new file mode 100644 index 0000000..dbeae63 --- /dev/null +++ b/src/azure/client.rs @@ -0,0 +1,1530 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::credential::AzureCredential; +use crate::azure::credential::*; +use crate::azure::{AzureCredentialProvider, STORE}; +use crate::client::builder::HttpRequestBuilder; +use crate::client::get::GetClient; +use crate::client::header::{get_put_result, HeaderConfig}; +use crate::client::list::ListClient; +use crate::client::retry::RetryExt; +use crate::client::{GetOptionsExt, HttpClient, HttpError, HttpRequest, HttpResponse}; +use crate::multipart::PartId; +use crate::path::DELIMITER; +use crate::util::{deserialize_rfc1123, GetRange}; +use crate::{ + Attribute, Attributes, ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, + PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, RetryConfig, TagSet, +}; +use async_trait::async_trait; +use base64::prelude::{BASE64_STANDARD, BASE64_STANDARD_NO_PAD}; +use base64::Engine; +use bytes::{Buf, Bytes}; +use chrono::{DateTime, Utc}; +use http::{ + header::{HeaderMap, HeaderValue, CONTENT_LENGTH, CONTENT_TYPE, IF_MATCH, IF_NONE_MATCH}, + HeaderName, Method, +}; +use rand::Rng as _; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use url::Url; + +const VERSION_HEADER: &str = "x-ms-version-id"; +const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-ms-meta-"; +static MS_CACHE_CONTROL: HeaderName = HeaderName::from_static("x-ms-blob-cache-control"); +static MS_CONTENT_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-content-type"); +static MS_CONTENT_DISPOSITION: HeaderName = + HeaderName::from_static("x-ms-blob-content-disposition"); +static MS_CONTENT_ENCODING: HeaderName = HeaderName::from_static("x-ms-blob-content-encoding"); +static MS_CONTENT_LANGUAGE: HeaderName = HeaderName::from_static("x-ms-blob-content-language"); + +static TAGS_HEADER: HeaderName = HeaderName::from_static("x-ms-tags"); + +/// A specialized `Error` for object store-related errors +#[derive(Debug, thiserror::Error)] +pub(crate) enum Error { + #[error("Error performing get request {}: {}", path, source)] + GetRequest { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error performing put request {}: {}", path, source)] + PutRequest { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error performing delete request {}: {}", path, source)] + DeleteRequest { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error performing bulk delete request: {}", source)] + BulkDeleteRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error receiving bulk delete request body: {}", source)] + BulkDeleteRequestBody { source: HttpError }, + + #[error( + "Bulk delete request failed due to invalid input: {} (code: {})", + reason, + code + )] + BulkDeleteRequestInvalidInput { code: String, reason: String }, + + #[error("Got invalid bulk delete response: {}", reason)] + InvalidBulkDeleteResponse { reason: String }, + + #[error( + "Bulk delete request failed for key {}: {} (code: {})", + path, + reason, + code + )] + DeleteFailed { + path: String, + code: String, + reason: String, + }, + + #[error("Error performing list request: {}", source)] + ListRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting list response body: {}", source)] + ListResponseBody { source: HttpError }, + + #[error("Got invalid list response: {}", source)] + InvalidListResponse { source: 
quick_xml::de::DeError }, + + #[error("Unable to extract metadata from headers: {}", source)] + Metadata { + source: crate::client::header::Error, + }, + + #[error("ETag required for conditional update")] + MissingETag, + + #[error("Error requesting user delegation key: {}", source)] + DelegationKeyRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting user delegation key response body: {}", source)] + DelegationKeyResponseBody { source: HttpError }, + + #[error("Got invalid user delegation key response: {}", source)] + DelegationKeyResponse { source: quick_xml::de::DeError }, + + #[error("Generating SAS keys with SAS tokens auth is not supported")] + SASforSASNotSupported, + + #[error("Generating SAS keys while skipping signatures is not supported")] + SASwithSkipSignature, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::PutRequest { source, path } => source.error(STORE, path), + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +/// Configuration for [AzureClient] +#[derive(Debug)] +pub(crate) struct AzureConfig { + pub account: String, + pub container: String, + pub credentials: AzureCredentialProvider, + pub retry_config: RetryConfig, + pub service: Url, + pub is_emulator: bool, + pub skip_signature: bool, + pub disable_tagging: bool, + pub client_options: ClientOptions, +} + +impl AzureConfig { + pub(crate) fn path_url(&self, path: &Path) -> Url { + let mut url = self.service.clone(); + { + let mut path_mut = url.path_segments_mut().unwrap(); + if self.is_emulator { + path_mut.push(&self.account); + } + path_mut.push(&self.container).extend(path.parts()); + } + url + } + async fn get_credential(&self) -> Result>> { + if self.skip_signature { + Ok(None) + } else { + Some(self.credentials.get_credential().await).transpose() + } + } +} + +/// A builder for a put request allowing customisation of the headers and query string +struct PutRequest<'a> { + path: &'a Path, + config: &'a AzureConfig, + payload: PutPayload, + builder: HttpRequestBuilder, + idempotent: bool, +} + +impl PutRequest<'_> { + fn header(self, k: &HeaderName, v: &str) -> Self { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + fn query(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + fn idempotent(self, idempotent: bool) -> Self { + Self { idempotent, ..self } + } + + fn with_tags(mut self, tags: TagSet) -> Self { + let tags = tags.encoded(); + if !tags.is_empty() && !self.config.disable_tagging { + self.builder = self.builder.header(&TAGS_HEADER, tags); + } + self + } + + fn with_attributes(self, attributes: Attributes) -> Self { + let mut builder = self.builder; + let mut has_content_type = false; + for (k, v) in &attributes { + builder = match k { + Attribute::CacheControl => builder.header(&MS_CACHE_CONTROL, v.as_ref()), + Attribute::ContentDisposition => { + builder.header(&MS_CONTENT_DISPOSITION, v.as_ref()) + } + Attribute::ContentEncoding => builder.header(&MS_CONTENT_ENCODING, v.as_ref()), + Attribute::ContentLanguage => builder.header(&MS_CONTENT_LANGUAGE, v.as_ref()), + Attribute::ContentType => { + has_content_type = true; + builder.header(&MS_CONTENT_TYPE, v.as_ref()) + } + Attribute::Metadata(k_suffix) => builder.header( + &format!("{}{}", USER_DEFINED_METADATA_HEADER_PREFIX, k_suffix), + v.as_ref(), + ), + }; + } + + if !has_content_type { + 
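+ // No Content-Type attribute was supplied explicitly; fall back to any
+ // content type the configured ClientOptions resolve for this path, if one
+ // is available.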
if let Some(value) = self.config.client_options.get_content_type(self.path) { + builder = builder.header(&MS_CONTENT_TYPE, value); + } + } + Self { builder, ..self } + } + + fn with_extensions(self, extensions: ::http::Extensions) -> Self { + let builder = self.builder.extensions(extensions); + Self { builder, ..self } + } + + async fn send(self) -> Result { + let credential = self.config.get_credential().await?; + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + let response = self + .builder + .header(CONTENT_LENGTH, self.payload.content_length()) + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .idempotent(self.idempotent) + .payload(Some(self.payload)) + .send() + .await + .map_err(|source| { + let path = self.path.as_ref().into(); + Error::PutRequest { path, source } + })?; + + Ok(response) + } +} + +#[inline] +fn extend(dst: &mut Vec, data: &[u8]) { + dst.extend_from_slice(data); +} + +// Write header names as title case. The header name is assumed to be ASCII. +// We need it because Azure is not always treating headers as case insensitive. +fn title_case(dst: &mut Vec, name: &[u8]) { + dst.reserve(name.len()); + + // Ensure first character is uppercased + let mut prev = b'-'; + for &(mut c) in name { + if prev == b'-' { + c.make_ascii_uppercase(); + } + dst.push(c); + prev = c; + } +} + +fn write_headers(headers: &HeaderMap, dst: &mut Vec) { + for (name, value) in headers { + // We need special case handling here otherwise Azure returns 400 + // due to `Content-Id` instead of `Content-ID` + if name == "content-id" { + extend(dst, b"Content-ID"); + } else { + title_case(dst, name.as_str().as_bytes()); + } + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } +} + +// https://docs.oasis-open.org/odata/odata/v4.0/errata02/os/complete/part1-protocol/odata-v4.0-errata02-os-part1-protocol-complete.html#_Toc406398359 +fn serialize_part_delete_request( + dst: &mut Vec, + boundary: &str, + idx: usize, + request: HttpRequest, + relative_url: String, +) { + // Encode start marker for part + extend(dst, b"--"); + extend(dst, boundary.as_bytes()); + extend(dst, b"\r\n"); + + // Encode part headers + let mut part_headers = HeaderMap::new(); + part_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/http")); + part_headers.insert( + "Content-Transfer-Encoding", + HeaderValue::from_static("binary"), + ); + // Azure returns 400 if we send `Content-Id` instead of `Content-ID` + part_headers.insert("Content-ID", HeaderValue::from(idx)); + write_headers(&part_headers, dst); + extend(dst, b"\r\n"); + + // Encode the subrequest request-line + extend(dst, b"DELETE "); + extend(dst, format!("/{} ", relative_url).as_bytes()); + extend(dst, b"HTTP/1.1"); + extend(dst, b"\r\n"); + + // Encode subrequest headers + write_headers(request.headers(), dst); + extend(dst, b"\r\n"); + extend(dst, b"\r\n"); +} + +fn parse_multipart_response_boundary(response: &HttpResponse) -> Result { + let invalid_response = |msg: &str| Error::InvalidBulkDeleteResponse { + reason: msg.to_string(), + }; + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .ok_or_else(|| invalid_response("missing Content-Type"))?; + + let boundary = content_type + .as_ref() + .strip_prefix(b"multipart/mixed; boundary=") + .ok_or_else(|| invalid_response("invalid Content-Type value"))? 
+ .to_vec(); + + let boundary = + String::from_utf8(boundary).map_err(|_| invalid_response("invalid multipart boundary"))?; + + Ok(boundary) +} + +fn invalid_response(msg: &str) -> Error { + Error::InvalidBulkDeleteResponse { + reason: msg.to_string(), + } +} + +#[derive(Debug)] +struct MultipartField { + headers: HeaderMap, + content: Bytes, +} + +fn parse_multipart_body_fields(body: Bytes, boundary: &[u8]) -> Result> { + let start_marker = [b"--", boundary, b"\r\n"].concat(); + let next_marker = &start_marker[..start_marker.len() - 2]; + let end_marker = [b"--", boundary, b"--\r\n"].concat(); + + // There should be at most 256 responses per batch + let mut fields = Vec::with_capacity(256); + let mut remaining: &[u8] = body.as_ref(); + loop { + remaining = remaining + .strip_prefix(start_marker.as_slice()) + .ok_or_else(|| invalid_response("missing start marker for field"))?; + + // The documentation only mentions two headers for fields, we leave some extra margin + let mut scratch = [httparse::EMPTY_HEADER; 10]; + let mut headers = HeaderMap::new(); + match httparse::parse_headers(remaining, &mut scratch) { + Ok(httparse::Status::Complete((pos, headers_slice))) => { + remaining = &remaining[pos..]; + for header in headers_slice { + headers.insert( + HeaderName::from_bytes(header.name.as_bytes()).expect("valid"), + HeaderValue::from_bytes(header.value).expect("valid"), + ); + } + } + _ => return Err(invalid_response("unable to parse field headers").into()), + }; + + let next_pos = remaining + .windows(next_marker.len()) + .position(|window| window == next_marker) + .ok_or_else(|| invalid_response("early EOF while seeking to next boundary"))?; + + fields.push(MultipartField { + headers, + content: body.slice_ref(&remaining[..next_pos]), + }); + + remaining = &remaining[next_pos..]; + + // Support missing final CRLF + if remaining == end_marker || remaining == &end_marker[..end_marker.len() - 2] { + break; + } + } + Ok(fields) +} + +async fn parse_blob_batch_delete_body( + batch_body: Bytes, + boundary: String, + paths: &[Path], +) -> Result>> { + let mut results: Vec> = paths.iter().cloned().map(Ok).collect(); + + for field in parse_multipart_body_fields(batch_body, boundary.as_bytes())? 
{ + let id = field + .headers + .get("content-id") + .and_then(|v| std::str::from_utf8(v.as_bytes()).ok()) + .and_then(|v| v.parse::().ok()); + + // Parse part response headers + // Documentation mentions 5 headers and states that other standard HTTP headers + // may be provided, in order to not incurr in more complexity to support an arbitrary + // amount of headers we chose a conservative amount and error otherwise + // https://learn.microsoft.com/en-us/rest/api/storageservices/delete-blob?tabs=microsoft-entra-id#response-headers + let mut headers = [httparse::EMPTY_HEADER; 48]; + let mut part_response = httparse::Response::new(&mut headers); + match part_response.parse(&field.content) { + Ok(httparse::Status::Complete(_)) => {} + _ => return Err(invalid_response("unable to parse response").into()), + }; + + match (id, part_response.code) { + (Some(_id), Some(code)) if (200..300).contains(&code) => {} + (Some(id), Some(404)) => { + results[id] = Err(crate::Error::NotFound { + path: paths[id].as_ref().to_string(), + source: Error::DeleteFailed { + path: paths[id].as_ref().to_string(), + code: 404.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into(), + }); + } + (Some(id), Some(code)) => { + results[id] = Err(Error::DeleteFailed { + path: paths[id].as_ref().to_string(), + code: code.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into()); + } + (None, Some(code)) => { + return Err(Error::BulkDeleteRequestInvalidInput { + code: code.to_string(), + reason: part_response.reason.unwrap_or_default().to_string(), + } + .into()) + } + _ => return Err(invalid_response("missing part response status code").into()), + } + } + + Ok(results) +} + +#[derive(Debug)] +pub(crate) struct AzureClient { + config: AzureConfig, + client: HttpClient, +} + +impl AzureClient { + /// create a new instance of [AzureClient] + pub(crate) fn new(config: AzureConfig, client: HttpClient) -> Self { + Self { config, client } + } + + /// Returns the config + pub(crate) fn config(&self) -> &AzureConfig { + &self.config + } + + async fn get_credential(&self) -> Result>> { + self.config.get_credential().await + } + + fn put_request<'a>(&'a self, path: &'a Path, payload: PutPayload) -> PutRequest<'a> { + let url = self.config.path_url(path); + let builder = self.client.request(Method::PUT, url.as_str()); + + PutRequest { + path, + builder, + payload, + config: &self.config, + idempotent: false, + } + } + + /// Make an Azure PUT request + pub(crate) async fn put_blob( + &self, + path: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let PutOptions { + mode, + tags, + attributes, + extensions, + } = opts; + + let builder = self + .put_request(path, payload) + .with_attributes(attributes) + .with_extensions(extensions) + .with_tags(tags); + + let builder = match &mode { + PutMode::Overwrite => builder.idempotent(true), + PutMode::Create => builder.header(&IF_NONE_MATCH, "*"), + PutMode::Update(v) => { + let etag = v.e_tag.as_ref().ok_or(Error::MissingETag)?; + builder.header(&IF_MATCH, etag) + } + }; + + let response = builder.header(&BLOB_TYPE, "BlockBlob").send().await?; + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) 
+ } + + /// PUT a block + pub(crate) async fn put_block( + &self, + path: &Path, + _part_idx: usize, + payload: PutPayload, + ) -> Result { + let part_idx = u128::from_be_bytes(rand::thread_rng().gen()); + let content_id = format!("{part_idx:032x}"); + let block_id = BASE64_STANDARD.encode(&content_id); + + self.put_request(path, payload) + .query(&[("comp", "block"), ("blockid", &block_id)]) + .idempotent(true) + .send() + .await?; + + Ok(PartId { content_id }) + } + + /// PUT a block list + pub(crate) async fn put_block_list( + &self, + path: &Path, + parts: Vec, + opts: PutMultipartOpts, + ) -> Result { + let PutMultipartOpts { + tags, + attributes, + extensions, + } = opts; + + let blocks = parts + .into_iter() + .map(|part| BlockId::from(part.content_id)) + .collect(); + + let payload = BlockList { blocks }.to_xml().into(); + let response = self + .put_request(path, payload) + .with_attributes(attributes) + .with_tags(tags) + .with_extensions(extensions) + .query(&[("comp", "blocklist")]) + .idempotent(true) + .send() + .await?; + + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) + } + + /// Make an Azure Delete request + pub(crate) async fn delete_request( + &self, + path: &Path, + query: &T, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + self.client + .delete(url.as_str()) + .query(query) + .header(&DELETE_SNAPSHOTS, "include") + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .send() + .await + .map_err(|source| { + let path = path.as_ref().into(); + Error::DeleteRequest { source, path } + })?; + + Ok(()) + } + + fn build_bulk_delete_body( + &self, + boundary: &str, + paths: &[Path], + credential: &Option>, + ) -> Vec { + let mut body_bytes = Vec::with_capacity(paths.len() * 2048); + + for (idx, path) in paths.iter().enumerate() { + let url = self.config.path_url(path); + + // Build subrequest with proper authorization + let request = self + .client + .delete(url.as_str()) + .header(CONTENT_LENGTH, HeaderValue::from(0)) + // Each subrequest must be authorized individually [1] and we use + // the CredentialExt for this. 
+ // [1]: https://learn.microsoft.com/en-us/rest/api/storageservices/blob-batch?tabs=microsoft-entra-id#request-body + .with_azure_authorization(credential, &self.config.account) + .into_parts() + .1 + .unwrap(); + + let url: Url = request.uri().to_string().parse().unwrap(); + + // Url for part requests must be relative and without base + let relative_url = self.config.service.make_relative(&url).unwrap(); + + serialize_part_delete_request(&mut body_bytes, boundary, idx, request, relative_url) + } + + // Encode end marker + extend(&mut body_bytes, b"--"); + extend(&mut body_bytes, boundary.as_bytes()); + extend(&mut body_bytes, b"--"); + extend(&mut body_bytes, b"\r\n"); + body_bytes + } + + pub(crate) async fn bulk_delete_request(&self, paths: Vec) -> Result>> { + if paths.is_empty() { + return Ok(Vec::new()); + } + + let credential = self.get_credential().await?; + + // https://www.ietf.org/rfc/rfc2046 + let random_bytes = rand::random::<[u8; 16]>(); // 128 bits + let boundary = format!("batch_{}", BASE64_STANDARD_NO_PAD.encode(random_bytes)); + + let body_bytes = self.build_bulk_delete_body(&boundary, &paths, &credential); + + // Send multipart request + let url = self.config.path_url(&Path::from("/")); + let batch_response = self + .client + .post(url.as_str()) + .query(&[("restype", "container"), ("comp", "batch")]) + .header( + CONTENT_TYPE, + HeaderValue::from_str(format!("multipart/mixed; boundary={}", boundary).as_str()) + .unwrap(), + ) + .header(CONTENT_LENGTH, HeaderValue::from(body_bytes.len())) + .body(body_bytes) + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| Error::BulkDeleteRequest { source })?; + + let boundary = parse_multipart_response_boundary(&batch_response)?; + + let batch_body = batch_response + .into_body() + .bytes() + .await + .map_err(|source| Error::BulkDeleteRequestBody { source })?; + + let results = parse_blob_batch_delete_body(batch_body, boundary, &paths).await?; + + Ok(results) + } + + /// Make an Azure Copy request + pub(crate) async fn copy_request(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(to); + let mut source = self.config.path_url(from); + + // If using SAS authorization must include the headers in the URL + // + if let Some(AzureCredential::SASToken(pairs)) = credential.as_deref() { + source.query_pairs_mut().extend_pairs(pairs); + } + + let mut builder = self + .client + .request(Method::PUT, url.as_str()) + .header(©_SOURCE, source.to_string()) + .header(CONTENT_LENGTH, HeaderValue::from_static("0")); + + if !overwrite { + builder = builder.header(IF_NONE_MATCH, "*"); + } + + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + builder + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .idempotent(overwrite) + .send() + .await + .map_err(|err| err.error(STORE, from.to_string()))?; + + Ok(()) + } + + /// Make a Get User Delegation Key request + /// + async fn get_user_delegation_key( + &self, + start: &DateTime, + end: &DateTime, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.config.service.clone(); + + let start = start.to_rfc3339_opts(chrono::SecondsFormat::Secs, true); + let expiry = end.to_rfc3339_opts(chrono::SecondsFormat::Secs, true); + + let mut body = String::new(); + body.push_str("\n\n"); + 
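+ // The request body is the KeyInfo XML document expected by the
+ // Get User Delegation Key API, carrying the signed start and expiry
+ // timestamps (the surrounding XML tags are appended below).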
body.push_str(&format!( + "\t{start}\n\t{expiry}\n" + )); + body.push_str(""); + + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + + let response = self + .client + .post(url.as_str()) + .body(body) + .query(&[("restype", "service"), ("comp", "userdelegationkey")]) + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .idempotent(true) + .send() + .await + .map_err(|source| Error::DelegationKeyRequest { source })? + .into_body() + .bytes() + .await + .map_err(|source| Error::DelegationKeyResponseBody { source })?; + + let response: UserDelegationKey = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::DelegationKeyResponse { source })?; + + Ok(response) + } + + /// Creat an AzureSigner for generating SAS tokens (pre-signed urls). + /// + /// Depending on the type of credential, this will either use the account key or a user delegation key. + /// Since delegation keys are acquired ad-hoc, the signer aloows for signing multiple urls with the same key. + pub(crate) async fn signer(&self, expires_in: Duration) -> Result { + let credential = self.get_credential().await?; + let signed_start = chrono::Utc::now(); + let signed_expiry = signed_start + expires_in; + match credential.as_deref() { + Some(AzureCredential::BearerToken(_)) => { + let key = self + .get_user_delegation_key(&signed_start, &signed_expiry) + .await?; + let signing_key = AzureAccessKey::try_new(&key.value)?; + Ok(AzureSigner::new( + signing_key, + self.config.account.clone(), + signed_start, + signed_expiry, + Some(key), + )) + } + Some(AzureCredential::AccessKey(key)) => Ok(AzureSigner::new( + key.to_owned(), + self.config.account.clone(), + signed_start, + signed_expiry, + None, + )), + None => Err(Error::SASwithSkipSignature.into()), + _ => Err(Error::SASforSASNotSupported.into()), + } + } + + #[cfg(test)] + pub(crate) async fn get_blob_tagging(&self, path: &Path) -> Result { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + let response = self + .client + .get(url.as_str()) + .query(&[("comp", "tags")]) + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .send() + .await + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } + })?; + + Ok(response) + } +} + +#[async_trait] +impl GetClient for AzureClient { + const STORE: &'static str = STORE; + + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: true, + last_modified_required: true, + version_header: Some(VERSION_HEADER), + user_defined_metadata_prefix: Some(USER_DEFINED_METADATA_HEADER_PREFIX), + }; + + /// Make an Azure GET request + /// + /// + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + // As of 2024-01-02, Azure does not support suffix requests, + // so we should fail fast here rather than sending one + if let Some(GetRange::Suffix(_)) = options.range.as_ref() { + return Err(crate::Error::NotSupported { + source: "Azure does not support suffix range requests".into(), + }); + } + + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut builder = self + .client + .request(method, 
url.as_str()) + .header(CONTENT_LENGTH, HeaderValue::from_static("0")) + .body(Bytes::new()); + + if let Some(v) = &options.version { + builder = builder.query(&[("versionid", v)]) + } + + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + let response = builder + .with_get_options(options) + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .send() + .await + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } + })?; + + match response.headers().get("x-ms-resource-type") { + Some(resource) if resource.as_ref() != b"file" => Err(crate::Error::NotFound { + path: path.to_string(), + source: format!( + "Not a file, got x-ms-resource-type: {}", + String::from_utf8_lossy(resource.as_ref()) + ) + .into(), + }), + _ => Ok(response), + } + } +} + +#[async_trait] +impl ListClient for Arc { + /// Make an Azure List request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + offset: Option<&str>, + ) -> Result<(ListResult, Option)> { + assert!(offset.is_none()); // Not yet supported + + let credential = self.get_credential().await?; + let url = self.config.path_url(&Path::default()); + + let mut query = Vec::with_capacity(5); + query.push(("restype", "container")); + query.push(("comp", "list")); + + if let Some(prefix) = prefix { + query.push(("prefix", prefix)) + } + + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + if let Some(token) = token { + query.push(("marker", token)) + } + + let sensitive = credential + .as_deref() + .map(|c| c.sensitive_request()) + .unwrap_or_default(); + let response = self + .client + .get(url.as_str()) + .query(&query) + .with_azure_authorization(&credential, &self.config.account) + .retryable(&self.config.retry_config) + .sensitive(sensitive) + .send() + .await + .map_err(|source| Error::ListRequest { source })? + .into_body() + .bytes() + .await + .map_err(|source| Error::ListResponseBody { source })?; + + let mut response: ListResultInternal = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; + + let token = response.next_marker.take(); + + Ok((to_list_result(response, prefix)?, token)) + } +} + +/// Raw / internal response from list requests +#[derive(Debug, Clone, PartialEq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct ListResultInternal { + pub prefix: Option, + pub max_results: Option, + pub delimiter: Option, + pub next_marker: Option, + pub blobs: Blobs, +} + +fn to_list_result(value: ListResultInternal, prefix: Option<&str>) -> Result { + let prefix = prefix.unwrap_or_default(); + let common_prefixes = value + .blobs + .blob_prefix + .into_iter() + .map(|x| Ok(Path::parse(x.name)?)) + .collect::>()?; + + let objects = value + .blobs + .blobs + .into_iter() + // Note: Filters out directories from list results when hierarchical namespaces are + // enabled. When we want directories, its always via the BlobPrefix mechanics, + // and during lists we state that prefixes are evaluated on path segment basis. + .filter(|blob| { + !matches!(blob.properties.resource_type.as_ref(), Some(typ) if typ == "directory") + && blob.name.len() > prefix.len() + }) + .map(ObjectMeta::try_from) + .collect::>()?; + + Ok(ListResult { + common_prefixes, + objects, + }) +} + +/// Collection of blobs and potentially shared prefixes returned from list requests. 
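+ /// Deserialized from the `Blobs` element of a List Blobs response body.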
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct Blobs { + #[serde(default)] + pub blob_prefix: Vec, + #[serde(rename = "Blob", default)] + pub blobs: Vec, +} + +/// Common prefix in list blobs response +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct BlobPrefix { + pub name: String, +} + +/// Details for a specific blob +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct Blob { + pub name: String, + pub version_id: Option, + pub is_current_version: Option, + pub deleted: Option, + pub properties: BlobProperties, + pub metadata: Option>, +} + +impl TryFrom for ObjectMeta { + type Error = crate::Error; + + fn try_from(value: Blob) -> Result { + Ok(Self { + location: Path::parse(value.name)?, + last_modified: value.properties.last_modified, + size: value.properties.content_length, + e_tag: value.properties.e_tag, + version: None, // For consistency with S3 and GCP which don't include this + }) + } +} + +/// Properties associated with individual blobs. The actual list +/// of returned properties is much more exhaustive, but we limit +/// the parsed fields to the ones relevant in this crate. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct BlobProperties { + #[serde(deserialize_with = "deserialize_rfc1123", rename = "Last-Modified")] + pub last_modified: DateTime, + #[serde(rename = "Content-Length")] + pub content_length: u64, + #[serde(rename = "Content-Type")] + pub content_type: String, + #[serde(rename = "Content-Encoding")] + pub content_encoding: Option, + #[serde(rename = "Content-Language")] + pub content_language: Option, + #[serde(rename = "Etag")] + pub e_tag: Option, + #[serde(rename = "ResourceType")] + pub resource_type: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct BlockId(Bytes); + +impl BlockId { + pub(crate) fn new(block_id: impl Into) -> Self { + Self(block_id.into()) + } +} + +impl From for BlockId +where + B: Into, +{ + fn from(v: B) -> Self { + Self::new(v) + } +} + +impl AsRef<[u8]> for BlockId { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub(crate) struct BlockList { + pub blocks: Vec, +} + +impl BlockList { + pub(crate) fn to_xml(&self) -> String { + let mut s = String::new(); + s.push_str("\n\n"); + for block_id in &self.blocks { + let node = format!( + "\t{}\n", + BASE64_STANDARD.encode(block_id) + ); + s.push_str(&node); + } + + s.push_str(""); + s + } +} + +#[derive(Debug, Clone, PartialEq, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct UserDelegationKey { + pub signed_oid: String, + pub signed_tid: String, + pub signed_start: String, + pub signed_expiry: String, + pub signed_service: String, + pub signed_version: String, + pub value: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::StaticCredentialProvider; + use bytes::Bytes; + use regex::bytes::Regex; + use reqwest::Client; + + #[test] + fn deserde_azure() { + const S: &str = " + + + + blob0.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 07 Jul 2022 14:38:48 GMT + 0x8D93C7D4629C227 + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + BlockBlob + Hot + true + unlocked + available + true + + uservalue + + + + blob1.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + 0x8D93C7D463004D6 + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + 
BlockBlob + Hot + true + unlocked + available + true + + + + + blob2.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + 0x8D93C7D4636478A + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + BlockBlob + Hot + true + unlocked + available + true + + + + + +"; + + let mut _list_blobs_response_internal: ListResultInternal = + quick_xml::de::from_str(S).unwrap(); + } + + #[test] + fn deserde_azurite() { + const S: &str = " + + + + 5000 + + + + blob0.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x228281B5D517B20 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + blob1.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x1DD959381A8A860 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + blob2.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x1FBE9C9B0C7B650 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + +"; + + let _list_blobs_response_internal: ListResultInternal = quick_xml::de::from_str(S).unwrap(); + } + + #[test] + fn to_xml() { + const S: &str = " + +\tbnVtZXJvMQ== +\tbnVtZXJvMg== +\tbnVtZXJvMw== +"; + let mut blocks = BlockList { blocks: Vec::new() }; + blocks.blocks.push(Bytes::from_static(b"numero1").into()); + blocks.blocks.push("numero2".into()); + blocks.blocks.push("numero3".into()); + + let res: &str = &blocks.to_xml(); + + assert_eq!(res, S) + } + + #[test] + fn test_delegated_key_response() { + const S: &str = r#" + + String containing a GUID value + String containing a GUID value + String formatted as ISO date + String formatted as ISO date + b + String specifying REST api version to use to create the user delegation key + String containing the user delegation key +"#; + + let _delegated_key_response_internal: UserDelegationKey = + quick_xml::de::from_str(S).unwrap(); + } + + #[tokio::test] + async fn test_build_bulk_delete_body() { + let credential_provider = Arc::new(StaticCredentialProvider::new( + AzureCredential::BearerToken("static-token".to_string()), + )); + + let config = AzureConfig { + account: "testaccount".to_string(), + container: "testcontainer".to_string(), + credentials: credential_provider, + service: "http://example.com".try_into().unwrap(), + retry_config: Default::default(), + is_emulator: false, + skip_signature: false, + disable_tagging: false, + client_options: Default::default(), + }; + + let client = AzureClient::new(config, HttpClient::new(Client::new())); + + let credential = client.get_credential().await.unwrap(); + let paths = &[Path::from("a"), Path::from("b"), Path::from("c")]; + + let boundary = "batch_statictestboundary".to_string(); + + let body_bytes = client.build_bulk_delete_body(&boundary, paths, &credential); + + // Replace Date header value with a static date + let re = Regex::new("Date:[^\r]+").unwrap(); + let body_bytes = re + .replace_all(&body_bytes, b"Date: Tue, 05 Nov 2024 15:01:15 GMT") + .to_vec(); + + let expected_body = b"--batch_statictestboundary\r +Content-Type: application/http\r +Content-Transfer-Encoding: binary\r +Content-ID: 0\r +\r +DELETE /testcontainer/a HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary\r +Content-Type: application/http\r 
+Content-Transfer-Encoding: binary\r +Content-ID: 1\r +\r +DELETE /testcontainer/b HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary\r +Content-Type: application/http\r +Content-Transfer-Encoding: binary\r +Content-ID: 2\r +\r +DELETE /testcontainer/c HTTP/1.1\r +Content-Length: 0\r +Date: Tue, 05 Nov 2024 15:01:15 GMT\r +X-Ms-Version: 2023-11-03\r +Authorization: Bearer static-token\r +\r +\r +--batch_statictestboundary--\r\n" + .to_vec(); + + assert_eq!(expected_body, body_bytes); + } + + #[tokio::test] + async fn test_parse_blob_batch_delete_body() { + let response_body = b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 0\r +\r +HTTP/1.1 202 Accepted\r +x-ms-delete-type-permanent: true\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r +x-ms-version: 2018-11-09\r +\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 1\r +\r +HTTP/1.1 202 Accepted\r +x-ms-delete-type-permanent: true\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2851\r +x-ms-version: 2018-11-09\r +\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r +Content-Type: application/http\r +Content-ID: 2\r +\r +HTTP/1.1 404 The specified blob does not exist.\r +x-ms-error-code: BlobNotFound\r +x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852\r +x-ms-version: 2018-11-09\r +Content-Length: 216\r +Content-Type: application/xml\r +\r + +BlobNotFoundThe specified blob does not exist. +RequestId:778fdc83-801e-0000-62ff-0334671e2852 +Time:2018-06-14T16:46:54.6040685Z\r +--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n"; + + let response: HttpResponse = http::Response::builder() + .status(202) + .header("Transfer-Encoding", "chunked") + .header( + "Content-Type", + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed", + ) + .header("x-ms-request-id", "778fdc83-801e-0000-62ff-033467000000") + .header("x-ms-version", "2018-11-09") + .body(Bytes::from(response_body.as_slice()).into()) + .unwrap(); + + let boundary = parse_multipart_response_boundary(&response).unwrap(); + let body = response.into_body().bytes().await.unwrap(); + + let paths = &[Path::from("a"), Path::from("b"), Path::from("c")]; + + let results = parse_blob_batch_delete_body(body, boundary, paths) + .await + .unwrap(); + + assert!(results[0].is_ok()); + assert_eq!(&paths[0], results[0].as_ref().unwrap()); + + assert!(results[1].is_ok()); + assert_eq!(&paths[1], results[1].as_ref().unwrap()); + + assert!(results[2].is_err()); + let err = results[2].as_ref().unwrap_err(); + let crate::Error::NotFound { source, .. } = err else { + unreachable!("must be not found") + }; + let Some(Error::DeleteFailed { path, code, reason }) = source.downcast_ref::() + else { + unreachable!("must be client error") + }; + + assert_eq!(paths[2].as_ref(), path); + assert_eq!("404", code); + assert_eq!("The specified blob does not exist.", reason); + } +} diff --git a/src/azure/credential.rs b/src/azure/credential.rs new file mode 100644 index 0000000..27f8776 --- /dev/null +++ b/src/azure/credential.rs @@ -0,0 +1,1220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::client::UserDelegationKey; +use crate::azure::STORE; +use crate::client::builder::{add_query_pairs, HttpRequestBuilder}; +use crate::client::retry::RetryExt; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::client::{CredentialProvider, HttpClient, HttpError, HttpRequest, TokenProvider}; +use crate::util::hmac_sha256; +use crate::RetryConfig; +use async_trait::async_trait; +use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD}; +use base64::Engine; +use chrono::{DateTime, SecondsFormat, Utc}; +use http::header::{ + HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, + CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, + IF_UNMODIFIED_SINCE, RANGE, +}; +use http::Method; +use serde::Deserialize; +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Deref; +use std::process::Command; +use std::str; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; +use url::Url; + +static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03"); +static VERSION: HeaderName = HeaderName::from_static("x-ms-version"); +pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type"); +pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots"); +pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source"); +static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5"); +static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token"); +static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier"); +static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker"); +static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host"); +pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; +const CONTENT_TYPE_JSON: &str = "application/json"; +const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER"; +const MSI_API_VERSION: &str = "2019-08-01"; +const TOKEN_MIN_TTL: u64 = 300; + +/// OIDC scope used when interacting with OAuth2 APIs +/// +/// +const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default"; + +/// Resource ID used when obtaining an access token from the metadata endpoint +/// +/// +const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com"; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Error performing token request: {}", source)] + TokenRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting token response body: {}", source)] + TokenResponseBody { source: HttpError }, + + #[error("Error reading federated token file ")] + FederatedTokenFile, + + #[error("Invalid Access Key: {}", source)] + InvalidAccessKey { source: base64::DecodeError }, + + #[error("'az account get-access-token' command failed: 
{message}")] + AzureCli { message: String }, + + #[error("Failed to parse azure cli response: {source}")] + AzureCliResponse { source: serde_json::Error }, + + #[error("Generating SAS keys with SAS tokens auth is not supported")] + SASforSASNotSupported, +} + +pub(crate) type Result = std::result::Result; + +impl From for crate::Error { + fn from(value: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(value), + } + } +} + +/// A shared Azure Storage Account Key +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct AzureAccessKey(Vec); + +impl AzureAccessKey { + /// Create a new [`AzureAccessKey`], checking it for validity + pub fn try_new(key: &str) -> Result { + let key = BASE64_STANDARD + .decode(key) + .map_err(|source| Error::InvalidAccessKey { source })?; + + Ok(Self(key)) + } +} + +/// An Azure storage credential +#[derive(Debug, Eq, PartialEq)] +pub enum AzureCredential { + /// A shared access key + /// + /// + AccessKey(AzureAccessKey), + /// A shared access signature + /// + /// + SASToken(Vec<(String, String)>), + /// An authorization token + /// + /// + BearerToken(String), +} + +impl AzureCredential { + /// Determines if the credential requires the request be treated as sensitive + pub fn sensitive_request(&self) -> bool { + match self { + Self::AccessKey(_) => false, + Self::BearerToken(_) => false, + // SAS tokens are sent as query parameters in the url + Self::SASToken(_) => true, + } + } +} + +/// A list of known Azure authority hosts +pub mod authority_hosts { + /// China-based Azure Authority Host + pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn"; + /// Germany-based Azure Authority Host + pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de"; + /// US Government Azure Authority Host + pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us"; + /// Public Cloud Azure Authority Host + pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com"; +} + +pub(crate) struct AzureSigner { + signing_key: AzureAccessKey, + start: DateTime, + end: DateTime, + account: String, + delegation_key: Option, +} + +impl AzureSigner { + pub(crate) fn new( + signing_key: AzureAccessKey, + account: String, + start: DateTime, + end: DateTime, + delegation_key: Option, + ) -> Self { + Self { + signing_key, + account, + start, + end, + delegation_key, + } + } + + pub(crate) fn sign(&self, method: &Method, url: &mut Url) -> Result<()> { + let (str_to_sign, query_pairs) = match &self.delegation_key { + Some(delegation_key) => string_to_sign_user_delegation_sas( + url, + method, + &self.account, + &self.start, + &self.end, + delegation_key, + ), + None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end), + }; + let auth = hmac_sha256(&self.signing_key.0, str_to_sign); + url.query_pairs_mut().extend_pairs(query_pairs); + url.query_pairs_mut() + .append_pair("sig", BASE64_STANDARD.encode(auth).as_str()); + Ok(()) + } +} + +fn add_date_and_version_headers(request: &mut HttpRequest) { + // rfc2822 string should never contain illegal characters + let date = Utc::now(); + let date_str = date.format(RFC1123_FMT).to_string(); + // we formatted the data string ourselves, so unwrapping should be fine + let date_val = HeaderValue::from_str(&date_str).unwrap(); + request.headers_mut().insert(DATE, date_val); + request + .headers_mut() + .insert(&VERSION, AZURE_VERSION.clone()); +} + +/// Authorize a [`HttpRequest`] with an [`AzureAuthorizer`] +#[derive(Debug)] +pub struct AzureAuthorizer<'a> { + credential: 
&'a AzureCredential, + account: &'a str, +} + +impl<'a> AzureAuthorizer<'a> { + /// Create a new [`AzureAuthorizer`] + pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self { + AzureAuthorizer { + credential, + account, + } + } + + /// Authorize `request` + pub fn authorize(&self, request: &mut HttpRequest) { + add_date_and_version_headers(request); + + match self.credential { + AzureCredential::AccessKey(key) => { + let url = Url::parse(&request.uri().to_string()).unwrap(); + let signature = generate_authorization( + request.headers(), + &url, + request.method(), + self.account, + key, + ); + + // "signature" is a base 64 encoded string so it should never + // contain illegal characters + request.headers_mut().append( + AUTHORIZATION, + HeaderValue::from_str(signature.as_str()).unwrap(), + ); + } + AzureCredential::BearerToken(token) => { + request.headers_mut().append( + AUTHORIZATION, + HeaderValue::from_str(format!("Bearer {}", token).as_str()).unwrap(), + ); + } + AzureCredential::SASToken(query_pairs) => { + add_query_pairs(request.uri_mut(), query_pairs); + } + } + } +} + +pub(crate) trait CredentialExt { + /// Apply authorization to requests against azure storage accounts + /// + fn with_azure_authorization( + self, + credential: &Option>, + account: &str, + ) -> Self; +} + +impl CredentialExt for HttpRequestBuilder { + fn with_azure_authorization( + self, + credential: &Option>, + account: &str, + ) -> Self { + let (client, request) = self.into_parts(); + let mut request = request.expect("request valid"); + + match credential.as_deref() { + Some(credential) => { + AzureAuthorizer::new(credential, account).authorize(&mut request); + } + None => { + add_date_and_version_headers(&mut request); + } + } + + Self::from_parts(client, request) + } +} + +/// Generate signed key for authorization via access keys +/// +fn generate_authorization( + h: &HeaderMap, + u: &Url, + method: &Method, + account: &str, + key: &AzureAccessKey, +) -> String { + let str_to_sign = string_to_sign(h, u, method, account); + let auth = hmac_sha256(&key.0, str_to_sign); + format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth)) +} + +fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str { + h.get(key) + .map(|s| s.to_str()) + .transpose() + .ok() + .flatten() + .unwrap_or_default() +} + +fn string_to_sign_sas( + u: &Url, + method: &Method, + account: &str, + start: &DateTime, + end: &DateTime, +) -> (String, String, String, String, String) { + // NOTE: for now only blob signing is supported. 
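+    // The signed resource scopes the SAS: "b" signs a single blob, while "c"
+    // would scope it to a container. The permission string below is then
+    // derived from the HTTP method of the request being signed.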
+ let signed_resource = "b".to_string(); + + // https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob + let signed_permissions = match *method { + // read and list permissions + Method::GET => match signed_resource.as_str() { + "c" => "rl", + "b" => "r", + _ => unreachable!(), + }, + // write permissions (also allows crating a new blob in a sub-key) + Method::PUT => "w", + // delete permissions + Method::DELETE => "d", + // other methods are not used in any of the current operations + _ => "", + } + .to_string(); + let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true); + let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true); + let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) { + format!("/blob/{}{}", account, u.path()) + } else { + // NOTE: in case of the emulator, the account name is not part of the host + // but the path starts with the account name + format!("/blob{}", u.path()) + }; + + ( + signed_resource, + signed_permissions, + signed_start, + signed_expiry, + canonicalized_resource, + ) +} + +/// Create a string to be signed for authorization via [service sas]. +/// +/// [service sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#version-2020-12-06-and-later +fn string_to_sign_service_sas( + u: &Url, + method: &Method, + account: &str, + start: &DateTime, + end: &DateTime, +) -> (String, HashMap<&'static str, String>) { + let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) = + string_to_sign_sas(u, method, account, start, end); + + let string_to_sign = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}", + signed_permissions, + signed_start, + signed_expiry, + canonicalized_resource, + "", // signed identifier + "", // signed ip + "", // signed protocol + &AZURE_VERSION.to_str().unwrap(), // signed version + signed_resource, // signed resource + "", // signed snapshot time + "", // signed encryption scope + "", // rscc - response header: Cache-Control + "", // rscd - response header: Content-Disposition + "", // rsce - response header: Content-Encoding + "", // rscl - response header: Content-Language + "", // rsct - response header: Content-Type + ); + + let mut pairs = HashMap::new(); + pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string()); + pairs.insert("sp", signed_permissions); + pairs.insert("st", signed_start); + pairs.insert("se", signed_expiry); + pairs.insert("sr", signed_resource); + + (string_to_sign, pairs) +} + +/// Create a string to be signed for authorization via [user delegation sas]. 
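+///
+/// Compared to a service SAS, the string to sign additionally covers the
+/// [`UserDelegationKey`] fields, and the `sk*` query pairs added below
+/// identify the delegation key used when signing the request.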
+/// +/// [user delegation sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-user-delegation-sas#version-2020-12-06-and-later +fn string_to_sign_user_delegation_sas( + u: &Url, + method: &Method, + account: &str, + start: &DateTime, + end: &DateTime, + delegation_key: &UserDelegationKey, +) -> (String, HashMap<&'static str, String>) { + let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) = + string_to_sign_sas(u, method, account, start, end); + + let string_to_sign = format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}", + signed_permissions, + signed_start, + signed_expiry, + canonicalized_resource, + delegation_key.signed_oid, // signed key object id + delegation_key.signed_tid, // signed key tenant id + delegation_key.signed_start, // signed key start + delegation_key.signed_expiry, // signed key expiry + delegation_key.signed_service, // signed key service + delegation_key.signed_version, // signed key version + "", // signed authorized user object id + "", // signed unauthorized user object id + "", // signed correlation id + "", // signed ip + "", // signed protocol + &AZURE_VERSION.to_str().unwrap(), // signed version + signed_resource, // signed resource + "", // signed snapshot time + "", // signed encryption scope + "", // rscc - response header: Cache-Control + "", // rscd - response header: Content-Disposition + "", // rsce - response header: Content-Encoding + "", // rscl - response header: Content-Language + "", // rsct - response header: Content-Type + ); + + let mut pairs = HashMap::new(); + pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string()); + pairs.insert("sp", signed_permissions); + pairs.insert("st", signed_start); + pairs.insert("se", signed_expiry); + pairs.insert("sr", signed_resource); + pairs.insert("skoid", delegation_key.signed_oid.clone()); + pairs.insert("sktid", delegation_key.signed_tid.clone()); + pairs.insert("skt", delegation_key.signed_start.clone()); + pairs.insert("ske", delegation_key.signed_expiry.clone()); + pairs.insert("sks", delegation_key.signed_service.clone()); + pairs.insert("skv", delegation_key.signed_version.clone()); + + (string_to_sign, pairs) +} + +/// +fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String { + // content length must only be specified if != 0 + // this is valid from 2015-02-21 + let content_length = h + .get(&CONTENT_LENGTH) + .map(|s| s.to_str()) + .transpose() + .ok() + .flatten() + .filter(|&v| v != "0") + .unwrap_or_default(); + format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}", + method.as_ref(), + add_if_exists(h, &CONTENT_ENCODING), + add_if_exists(h, &CONTENT_LANGUAGE), + content_length, + add_if_exists(h, &CONTENT_MD5), + add_if_exists(h, &CONTENT_TYPE), + add_if_exists(h, &DATE), + add_if_exists(h, &IF_MODIFIED_SINCE), + add_if_exists(h, &IF_MATCH), + add_if_exists(h, &IF_NONE_MATCH), + add_if_exists(h, &IF_UNMODIFIED_SINCE), + add_if_exists(h, &RANGE), + canonicalize_header(h), + canonicalize_resource(account, u) + ) +} + +/// +fn canonicalize_header(headers: &HeaderMap) -> String { + let mut names = headers + .iter() + .filter(|&(k, _)| (k.as_str().starts_with("x-ms"))) + // TODO remove unwraps + .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap())) + .collect::>(); + names.sort_unstable(); + + let mut result = String::new(); + for (name, value) in names { + result.push_str(name); + result.push(':'); + 
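+        // The header value follows the lowercased name; each `name:value` pair
+        // is terminated by a newline, with names sorted lexicographically above.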
result.push_str(value); + result.push('\n'); + } + result +} + +/// +fn canonicalize_resource(account: &str, uri: &Url) -> String { + let mut can_res: String = String::new(); + can_res.push('/'); + can_res.push_str(account); + can_res.push_str(uri.path().to_string().as_str()); + can_res.push('\n'); + + // query parameters + let query_pairs = uri.query_pairs(); + { + let mut qps: Vec = Vec::new(); + for (q, _) in query_pairs { + if !(qps.iter().any(|x| x == &*q)) { + qps.push(q.into_owned()); + } + } + + qps.sort(); + + for qparam in qps { + // find correct parameter + let ret = lexy_sort(query_pairs, &qparam); + + can_res = can_res + &qparam.to_lowercase() + ":"; + + for (i, item) in ret.iter().enumerate() { + if i > 0 { + can_res.push(','); + } + can_res.push_str(item); + } + + can_res.push('\n'); + } + }; + + can_res[0..can_res.len() - 1].to_owned() +} + +fn lexy_sort<'a>( + vec: impl Iterator, Cow<'a, str>)> + 'a, + query_param: &str, +) -> Vec> { + let mut values = vec + .filter(|(k, _)| *k == query_param) + .map(|(_, v)| v) + .collect::>(); + values.sort_unstable(); + values +} + +/// +#[derive(Deserialize, Debug)] +struct OAuthTokenResponse { + access_token: String, + expires_in: u64, +} + +/// Encapsulates the logic to perform an OAuth token challenge +/// +/// +#[derive(Debug)] +pub(crate) struct ClientSecretOAuthProvider { + token_url: String, + client_id: String, + client_secret: String, +} + +impl ClientSecretOAuthProvider { + /// Create a new [`ClientSecretOAuthProvider`] for an azure backed store + pub(crate) fn new( + client_id: String, + client_secret: String, + tenant_id: impl AsRef, + authority_host: Option, + ) -> Self { + let authority_host = + authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned()); + + Self { + token_url: format!( + "{}/{}/oauth2/v2.0/token", + authority_host, + tenant_id.as_ref() + ), + client_id, + client_secret, + } + } +} + +#[async_trait::async_trait] +impl TokenProvider for ClientSecretOAuthProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let response: OAuthTokenResponse = client + .request(Method::POST, &self.token_url) + .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) + .form([ + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), + ("scope", AZURE_STORAGE_SCOPE), + ("grant_type", "client_credentials"), + ]) + .retryable(retry) + .idempotent(true) + .send() + .await + .map_err(|source| Error::TokenRequest { source })? + .into_body() + .json() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), + expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + }) + } +} + +fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result +where + D: serde::de::Deserializer<'de>, +{ + let v = String::deserialize(deserializer)?; + let v = v.parse::().map_err(serde::de::Error::custom)?; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(serde::de::Error::custom)?; + + Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs()))) +} + +/// NOTE: expires_on is a String version of unix epoch time, not an integer. 
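+/// For example, the metadata endpoint may return `"expires_on": "1506484173"`,
+/// which [`expires_on_string`] converts into an [`Instant`] relative to now.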
+/// +/// +#[derive(Debug, Clone, Deserialize)] +struct ImdsTokenResponse { + pub access_token: String, + #[serde(deserialize_with = "expires_on_string")] + pub expires_on: Instant, +} + +/// Attempts authentication using a managed identity that has been assigned to the deployment environment. +/// +/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell +/// +#[derive(Debug)] +pub(crate) struct ImdsManagedIdentityProvider { + msi_endpoint: String, + client_id: Option, + object_id: Option, + msi_res_id: Option, +} + +impl ImdsManagedIdentityProvider { + /// Create a new [`ImdsManagedIdentityProvider`] for an azure backed store + pub(crate) fn new( + client_id: Option, + object_id: Option, + msi_res_id: Option, + msi_endpoint: Option, + ) -> Self { + let msi_endpoint = msi_endpoint + .unwrap_or_else(|| "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()); + + Self { + msi_endpoint, + client_id, + object_id, + msi_res_id, + } + } +} + +#[async_trait::async_trait] +impl TokenProvider for ImdsManagedIdentityProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let mut query_items = vec![ + ("api-version", MSI_API_VERSION), + ("resource", AZURE_STORAGE_RESOURCE), + ]; + + let mut identity = None; + if let Some(client_id) = &self.client_id { + identity = Some(("client_id", client_id)); + } + if let Some(object_id) = &self.object_id { + identity = Some(("object_id", object_id)); + } + if let Some(msi_res_id) = &self.msi_res_id { + identity = Some(("msi_res_id", msi_res_id)); + } + if let Some((key, value)) = identity { + query_items.push((key, value)); + } + + let mut builder = client + .request(Method::GET, &self.msi_endpoint) + .header("metadata", "true") + .query(&query_items); + + if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) { + builder = builder.header("x-identity-header", val); + }; + + let response: ImdsTokenResponse = builder + .send_retry(retry) + .await + .map_err(|source| Error::TokenRequest { source })? 
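+            // Unlike the OAuth flows, IMDS reports an absolute `expires_on`
+            // epoch timestamp, converted to an `Instant` during deserialization.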
+ .into_body() + .json() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), + expiry: Some(response.expires_on), + }) + } +} + +/// Credential for using workload identity federation +/// +/// +#[derive(Debug)] +pub(crate) struct WorkloadIdentityOAuthProvider { + token_url: String, + client_id: String, + federated_token_file: String, +} + +impl WorkloadIdentityOAuthProvider { + /// Create a new [`WorkloadIdentityOAuthProvider`] for an azure backed store + pub(crate) fn new( + client_id: impl Into, + federated_token_file: impl Into, + tenant_id: impl AsRef, + authority_host: Option, + ) -> Self { + let authority_host = + authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned()); + + Self { + token_url: format!( + "{}/{}/oauth2/v2.0/token", + authority_host, + tenant_id.as_ref() + ), + client_id: client_id.into(), + federated_token_file: federated_token_file.into(), + } + } +} + +#[async_trait::async_trait] +impl TokenProvider for WorkloadIdentityOAuthProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let token_str = std::fs::read_to_string(&self.federated_token_file) + .map_err(|_| Error::FederatedTokenFile)?; + + // https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential + let response: OAuthTokenResponse = client + .request(Method::POST, &self.token_url) + .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) + .form([ + ("client_id", self.client_id.as_str()), + ( + "client_assertion_type", + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + ), + ("client_assertion", token_str.as_str()), + ("scope", AZURE_STORAGE_SCOPE), + ("grant_type", "client_credentials"), + ]) + .retryable(retry) + .idempotent(true) + .send() + .await + .map_err(|source| Error::TokenRequest { source })? 
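+            // The reply is a standard OAuth token response; `expires_in` below
+            // determines how long the cached credential remains valid.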
+ .into_body() + .json() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(response.access_token)), + expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + }) + } +} + +mod az_cli_date_format { + use chrono::{DateTime, TimeZone}; + use serde::{self, Deserialize, Deserializer}; + + pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + // expiresOn from azure cli uses the local timezone + let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f") + .map_err(serde::de::Error::custom)?; + chrono::Local + .from_local_datetime(&date) + .single() + .ok_or(serde::de::Error::custom( + "azure cli returned ambiguous expiry date", + )) + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AzureCliTokenResponse { + pub access_token: String, + #[serde(with = "az_cli_date_format")] + pub expires_on: DateTime, + pub token_type: String, +} + +#[derive(Default, Debug)] +pub(crate) struct AzureCliCredential { + cache: TokenCache>, +} + +impl AzureCliCredential { + pub(crate) fn new() -> Self { + Self::default() + } + + /// Fetch a token + async fn fetch_token(&self) -> Result>> { + // on window az is a cmd and it should be called like this + // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html + let program = if cfg!(target_os = "windows") { + "cmd" + } else { + "az" + }; + let mut args = Vec::new(); + if cfg!(target_os = "windows") { + args.push("/C"); + args.push("az"); + } + args.push("account"); + args.push("get-access-token"); + args.push("--output"); + args.push("json"); + args.push("--scope"); + args.push(AZURE_STORAGE_SCOPE); + + match Command::new(program).args(args).output() { + Ok(az_output) if az_output.status.success() => { + let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli { + message: "az response is not a valid utf-8 string".to_string(), + })?; + + let token_response = serde_json::from_str::(output) + .map_err(|source| Error::AzureCliResponse { source })?; + + if !token_response.token_type.eq_ignore_ascii_case("bearer") { + return Err(Error::AzureCli { + message: format!( + "got unexpected token type from azure cli: {0}", + token_response.token_type + ), + }); + } + let duration = + token_response.expires_on.naive_local() - chrono::Local::now().naive_local(); + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(token_response.access_token)), + expiry: Some( + Instant::now() + + duration.to_std().map_err(|_| Error::AzureCli { + message: "az returned invalid lifetime".to_string(), + })?, + ), + }) + } + Ok(az_output) => { + let message = String::from_utf8_lossy(&az_output.stderr); + Err(Error::AzureCli { + message: message.into(), + }) + } + Err(e) => match e.kind() { + std::io::ErrorKind::NotFound => Err(Error::AzureCli { + message: "Azure Cli not installed".into(), + }), + error_kind => Err(Error::AzureCli { + message: format!("io error: {error_kind:?}"), + }), + }, + } + } +} + +/// Encapsulates the logic to perform an OAuth token challenge for Fabric +#[derive(Debug)] +pub(crate) struct FabricTokenOAuthProvider { + fabric_token_service_url: String, + fabric_workload_host: String, + fabric_session_token: String, + fabric_cluster_identifier: String, + storage_access_token: Option, + token_expiry: Option, +} + +#[derive(Debug, Deserialize)] +struct Claims { + exp: u64, 
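+    // `exp` is the expiry in seconds since the Unix epoch; it is the only claim
+    // needed to decide whether a cached storage token can still be reused.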
+} + +impl FabricTokenOAuthProvider { + /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store + pub(crate) fn new( + fabric_token_service_url: impl Into, + fabric_workload_host: impl Into, + fabric_session_token: impl Into, + fabric_cluster_identifier: impl Into, + storage_access_token: Option, + ) -> Self { + let (storage_access_token, token_expiry) = match storage_access_token { + Some(token) => match Self::validate_and_get_expiry(&token) { + Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => { + (Some(token), Some(expiry)) + } + _ => (None, None), + }, + None => (None, None), + }; + + Self { + fabric_token_service_url: fabric_token_service_url.into(), + fabric_workload_host: fabric_workload_host.into(), + fabric_session_token: fabric_session_token.into(), + fabric_cluster_identifier: fabric_cluster_identifier.into(), + storage_access_token, + token_expiry, + } + } + + fn validate_and_get_expiry(token: &str) -> Option { + let payload = token.split('.').nth(1)?; + let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?; + let decoded_str = str::from_utf8(&decoded_bytes).ok()?; + let claims: Claims = serde_json::from_str(decoded_str).ok()?; + Some(claims.exp) + } + + fn get_current_timestamp() -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()) + } +} + +#[async_trait::async_trait] +impl TokenProvider for FabricTokenOAuthProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + if let Some(storage_access_token) = &self.storage_access_token { + if let Some(expiry) = self.token_expiry { + let exp_in = expiry - Self::get_current_timestamp(); + if exp_in > TOKEN_MIN_TTL { + return Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), + }); + } + } + } + + let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)]; + let access_token: String = client + .request(Method::GET, &self.fabric_token_service_url) + .header(&PARTNER_TOKEN, self.fabric_session_token.as_str()) + .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str()) + .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str()) + .header(&PROXY_HOST, self.fabric_workload_host.as_str()) + .query(&query_items) + .retryable(retry) + .idempotent(true) + .send() + .await + .map_err(|source| Error::TokenRequest { source })? + .into_body() + .text() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + let exp_in = Self::validate_and_get_expiry(&access_token) + .map_or(3600, |expiry| expiry - Self::get_current_timestamp()); + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(access_token)), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), + }) + } +} + +#[async_trait] +impl CredentialProvider for AzureCliCredential { + type Credential = AzureCredential; + + async fn get_credential(&self) -> crate::Result> { + Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?) 
+ } +} + +#[cfg(test)] +mod tests { + use futures::executor::block_on; + use http::{Response, StatusCode}; + use http_body_util::BodyExt; + use reqwest::{Client, Method}; + use tempfile::NamedTempFile; + + use super::*; + use crate::azure::MicrosoftAzureBuilder; + use crate::client::mock_server::MockServer; + use crate::{ObjectStore, Path}; + + #[tokio::test] + async fn test_managed_identity() { + let server = MockServer::new().await; + + std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); + + let endpoint = server.url(); + let client = HttpClient::new(Client::new()); + let retry_config = RetryConfig::default(); + + // Test IMDS + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token"); + assert!(req.uri().query().unwrap().contains("client_id=client_id")); + assert_eq!(req.method(), &Method::GET); + let t = req + .headers() + .get("x-identity-header") + .unwrap() + .to_str() + .unwrap(); + assert_eq!(t, "env-secret"); + let t = req.headers().get("metadata").unwrap().to_str().unwrap(); + assert_eq!(t, "true"); + Response::new( + r#" + { + "access_token": "TOKEN", + "refresh_token": "", + "expires_in": "3599", + "expires_on": "1506484173", + "not_before": "1506480273", + "resource": "https://management.azure.com/", + "token_type": "Bearer" + } + "# + .to_string(), + ) + }); + + let credential = ImdsManagedIdentityProvider::new( + Some("client_id".into()), + None, + None, + Some(format!("{endpoint}/metadata/identity/oauth2/token")), + ); + + let token = credential + .fetch_token(&client, &retry_config) + .await + .unwrap(); + + assert_eq!( + token.token.as_ref(), + &AzureCredential::BearerToken("TOKEN".into()) + ); + } + + #[tokio::test] + async fn test_workload_identity() { + let server = MockServer::new().await; + let tokenfile = NamedTempFile::new().unwrap(); + let tenant = "tenant"; + std::fs::write(tokenfile.path(), "federated-token").unwrap(); + + let endpoint = server.url(); + let client = HttpClient::new(Client::new()); + let retry_config = RetryConfig::default(); + + // Test IMDS + server.push_fn(move |req| { + assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); + assert_eq!(req.method(), &Method::POST); + let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() }); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.contains("federated-token")); + Response::new( + r#" + { + "access_token": "TOKEN", + "refresh_token": "", + "expires_in": 3599, + "expires_on": "1506484173", + "not_before": "1506480273", + "resource": "https://management.azure.com/", + "token_type": "Bearer" + } + "# + .to_string(), + ) + }); + + let credential = WorkloadIdentityOAuthProvider::new( + "client_id", + tokenfile.path().to_str().unwrap(), + tenant, + Some(endpoint.to_string()), + ); + + let token = credential + .fetch_token(&client, &retry_config) + .await + .unwrap(); + + assert_eq!( + token.token.as_ref(), + &AzureCredential::BearerToken("TOKEN".into()) + ); + } + + #[tokio::test] + async fn test_no_credentials() { + let server = MockServer::new().await; + + let endpoint = server.url(); + let store = MicrosoftAzureBuilder::new() + .with_account("test") + .with_container_name("test") + .with_allow_http(true) + .with_bearer_token_authorization("token") + .with_endpoint(endpoint.to_string()) + .with_skip_signature(true) + .build() + .unwrap(); + + server.push_fn(|req| { + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get("Authorization").is_none()); + Response::builder() + 
.status(StatusCode::NOT_FOUND) + .body("not found".to_string()) + .unwrap() + }); + + let path = Path::from("file.txt"); + match store.get(&path).await { + Err(crate::Error::NotFound { .. }) => {} + _ => { + panic!("unexpected response"); + } + } + } +} diff --git a/src/azure/mod.rs b/src/azure/mod.rs new file mode 100644 index 0000000..b4243dd --- /dev/null +++ b/src/azure/mod.rs @@ -0,0 +1,396 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for Azure blob storage +//! +//! ## Streaming uploads +//! +//! [ObjectStore::put_multipart] will upload data in blocks and write a blob from those blocks. +//! +//! Unused blocks will automatically be dropped after 7 days. +use crate::{ + multipart::{MultipartStore, PartId}, + path::Path, + signer::Signer, + GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, UploadPart, +}; +use async_trait::async_trait; +use futures::stream::{BoxStream, StreamExt, TryStreamExt}; +use reqwest::Method; +use std::fmt::Debug; +use std::sync::Arc; +use std::time::Duration; +use url::Url; + +use crate::client::get::GetClientExt; +use crate::client::list::ListClientExt; +use crate::client::CredentialProvider; +pub use credential::{authority_hosts, AzureAccessKey, AzureAuthorizer}; + +mod builder; +mod client; +mod credential; + +/// [`CredentialProvider`] for [`MicrosoftAzure`] +pub type AzureCredentialProvider = Arc>; +use crate::azure::client::AzureClient; +use crate::client::parts::Parts; +pub use builder::{AzureConfigKey, MicrosoftAzureBuilder}; +pub use credential::AzureCredential; + +const STORE: &str = "MicrosoftAzure"; + +/// Interface for [Microsoft Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/blobs/). +#[derive(Debug)] +pub struct MicrosoftAzure { + client: Arc, +} + +impl MicrosoftAzure { + /// Returns the [`AzureCredentialProvider`] used by [`MicrosoftAzure`] + pub fn credentials(&self) -> &AzureCredentialProvider { + &self.client.config().credentials + } + + /// Create a full URL to the resource specified by `path` with this instance's configuration. 
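+    ///
+    /// With the default public endpoint this is typically of the form
+    /// `https://{account}.blob.core.windows.net/{container}/{path}`.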
+ fn path_url(&self, path: &Path) -> Url { + self.client.config().path_url(path) + } +} + +impl std::fmt::Display for MicrosoftAzure { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MicrosoftAzure {{ account: {}, container: {} }}", + self.client.config().account, + self.client.config().container + ) + } +} + +#[async_trait] +impl ObjectStore for MicrosoftAzure { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.client.put_blob(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + Ok(Box::new(AzureMultiPartUpload { + part_idx: 0, + opts, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + parts: Default::default(), + }), + })) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.client.get_opts(location, options).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location, &()).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.client.list(prefix) + } + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + locations + .try_chunks(256) + .map(move |locations| async { + // Early return the error. We ignore the paths that have already been + // collected into the chunk. + let locations = locations.map_err(|e| e.1)?; + self.client + .bulk_delete_request(locations) + .await + .map(futures::stream::iter) + }) + .buffered(20) + .try_flatten() + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.client.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, true).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, false).await + } +} + +#[async_trait] +impl Signer for MicrosoftAzure { + /// Create a URL containing the relevant [Service SAS] query parameters that authorize a request + /// via `method` to the resource at `path` valid for the duration specified in `expires_in`. + /// + /// [Service SAS]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas + /// + /// # Example + /// + /// This example returns a URL that will enable a user to upload a file to + /// "some-folder/some-file.txt" in the next hour. 
+ /// + /// ``` + /// # async fn example() -> Result<(), Box> { + /// # use object_store::{azure::MicrosoftAzureBuilder, path::Path, signer::Signer}; + /// # use reqwest::Method; + /// # use std::time::Duration; + /// # + /// let azure = MicrosoftAzureBuilder::new() + /// .with_account("my-account") + /// .with_access_key("my-access-key") + /// .with_container_name("my-container") + /// .build()?; + /// + /// let url = azure.signed_url( + /// Method::PUT, + /// &Path::from("some-folder/some-file.txt"), + /// Duration::from_secs(60 * 60) + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + async fn signed_url(&self, method: Method, path: &Path, expires_in: Duration) -> Result { + let mut url = self.path_url(path); + let signer = self.client.signer(expires_in).await?; + signer.sign(&method, &mut url)?; + Ok(url) + } + + async fn signed_urls( + &self, + method: Method, + paths: &[Path], + expires_in: Duration, + ) -> Result> { + let mut urls = Vec::with_capacity(paths.len()); + let signer = self.client.signer(expires_in).await?; + for path in paths { + let mut url = self.path_url(path); + signer.sign(&method, &mut url)?; + urls.push(url); + } + Ok(urls) + } +} + +/// Relevant docs: +/// In Azure Blob Store, parts are "blocks" +/// put_multipart_part -> PUT block +/// complete -> PUT block list +/// abort -> No equivalent; blocks are simply dropped after 7 days +#[derive(Debug)] +struct AzureMultiPartUpload { + part_idx: usize, + state: Arc, + opts: PutMultipartOpts, +} + +#[derive(Debug)] +struct UploadState { + location: Path, + parts: Parts, + client: Arc, +} + +#[async_trait] +impl MultipartUpload for AzureMultiPartUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state.client.put_block(&state.location, idx, data).await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .put_block_list(&self.state.location, parts, std::mem::take(&mut self.opts)) + .await + } + + async fn abort(&mut self) -> Result<()> { + // Nothing to do + Ok(()) + } +} + +#[async_trait] +impl MultipartStore for MicrosoftAzure { + async fn create_multipart(&self, _: &Path) -> Result { + Ok(String::new()) + } + + async fn put_part( + &self, + path: &Path, + _: &MultipartId, + part_idx: usize, + data: PutPayload, + ) -> Result { + self.client.put_block(path, part_idx, data).await + } + + async fn complete_multipart( + &self, + path: &Path, + _: &MultipartId, + parts: Vec, + ) -> Result { + self.client + .put_block_list(path, parts, Default::default()) + .await + } + + async fn abort_multipart(&self, _: &Path, _: &MultipartId) -> Result<()> { + // There is no way to drop blocks that have been uploaded. Instead, they simply + // expire in 7 days. 
+ Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::integration::*; + use crate::tests::*; + use bytes::Bytes; + + #[tokio::test] + async fn azure_blob_test() { + maybe_skip_integration!(); + let integration = MicrosoftAzureBuilder::from_env().build().unwrap(); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + put_opts(&integration, true).await; + multipart(&integration, &integration).await; + multipart_race_condition(&integration, false).await; + multipart_out_of_order(&integration).await; + signing(&integration).await; + + let validate = !integration.client.config().disable_tagging; + tagging( + Arc::new(MicrosoftAzure { + client: Arc::clone(&integration.client), + }), + validate, + |p| { + let client = Arc::clone(&integration.client); + async move { client.get_blob_tagging(&p).await } + }, + ) + .await; + + // Azurite doesn't support attributes properly + if !integration.client.config().is_emulator { + put_get_attributes(&integration).await; + } + } + + #[ignore = "Used for manual testing against a real storage account."] + #[tokio::test] + async fn test_user_delegation_key() { + let account = std::env::var("AZURE_ACCOUNT_NAME").unwrap(); + let container = std::env::var("AZURE_CONTAINER_NAME").unwrap(); + let client_id = std::env::var("AZURE_CLIENT_ID").unwrap(); + let client_secret = std::env::var("AZURE_CLIENT_SECRET").unwrap(); + let tenant_id = std::env::var("AZURE_TENANT_ID").unwrap(); + let integration = MicrosoftAzureBuilder::new() + .with_account(account) + .with_container_name(container) + .with_client_id(client_id) + .with_client_secret(client_secret) + .with_tenant_id(&tenant_id) + .build() + .unwrap(); + + let data = Bytes::from("hello world"); + let path = Path::from("file.txt"); + integration.put(&path, data.clone().into()).await.unwrap(); + + let signed = integration + .signed_url(Method::GET, &path, Duration::from_secs(60)) + .await + .unwrap(); + + let resp = reqwest::get(signed).await.unwrap(); + let loaded = resp.bytes().await.unwrap(); + + assert_eq!(data, loaded); + } + + #[test] + fn azure_test_config_get_value() { + let azure_client_id = "object_store:fake_access_key_id".to_string(); + let azure_storage_account_name = "object_store:fake_secret_key".to_string(); + let azure_storage_token = "object_store:fake_default_region".to_string(); + let builder = MicrosoftAzureBuilder::new() + .with_config(AzureConfigKey::ClientId, &azure_client_id) + .with_config(AzureConfigKey::AccountName, &azure_storage_account_name) + .with_config(AzureConfigKey::Token, &azure_storage_token); + + assert_eq!( + builder.get_config_value(&AzureConfigKey::ClientId).unwrap(), + azure_client_id + ); + assert_eq!( + builder + .get_config_value(&AzureConfigKey::AccountName) + .unwrap(), + azure_storage_account_name + ); + assert_eq!( + builder.get_config_value(&AzureConfigKey::Token).unwrap(), + azure_storage_token + ); + } +} diff --git a/src/buffered.rs b/src/buffered.rs new file mode 100644 index 0000000..a767cb6 --- /dev/null +++ b/src/buffered.rs @@ -0,0 +1,679 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for performing tokio-style buffered IO + +use crate::path::Path; +use crate::{ + Attributes, ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayloadMut, TagSet, + WriteMultipart, +}; +use bytes::Bytes; +use futures::future::{BoxFuture, FutureExt}; +use futures::ready; +use std::cmp::Ordering; +use std::io::{Error, ErrorKind, SeekFrom}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; + +/// The default buffer size used by [`BufReader`] +pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024; + +/// An async-buffered reader compatible with the tokio IO traits +/// +/// Internally this maintains a buffer of the requested size, and uses [`ObjectStore::get_range`] +/// to populate its internal buffer once depleted. This buffer is cleared on seek. +/// +/// Whilst simple, this interface will typically be outperformed by the native [`ObjectStore`] +/// methods that better map to the network APIs. This is because most object stores have +/// very [high first-byte latencies], on the order of 100-200ms, and so avoiding unnecessary +/// round-trips is critical to throughput. +/// +/// Systems looking to sequentially scan a file should instead consider using [`ObjectStore::get`], +/// or [`ObjectStore::get_opts`], or [`ObjectStore::get_range`] to read a particular range. +/// +/// Systems looking to read multiple ranges of a file should instead consider using +/// [`ObjectStore::get_ranges`], which will optimise the vectored IO. 
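+///
+/// # Example
+///
+/// A minimal sketch of buffered reading, using this crate's in-memory store
+/// purely for illustration:
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use bytes::Bytes;
+/// # use object_store::{memory::InMemory, path::Path, ObjectStore};
+/// # use object_store::buffered::BufReader;
+/// # use tokio::io::AsyncReadExt;
+/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+/// let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
+/// let path = Path::from("exists.txt");
+/// store.put(&path, Bytes::from("hello world").into()).await?;
+///
+/// // Fetch the object's metadata, then read it through the buffered reader
+/// let meta = store.head(&path).await?;
+/// let mut reader = BufReader::new(Arc::clone(&store), &meta);
+/// let mut out = Vec::new();
+/// reader.read_to_end(&mut out).await?;
+/// assert_eq!(out, b"hello world".to_vec());
+/// # Ok(()) }
+/// ```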
+/// +/// [high first-byte latencies]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance.html +pub struct BufReader { + /// The object store to fetch data from + store: Arc, + /// The size of the object + size: u64, + /// The path to the object + path: Path, + /// The current position in the object + cursor: u64, + /// The number of bytes to read in a single request + capacity: usize, + /// The buffered data if any + buffer: Buffer, +} + +impl std::fmt::Debug for BufReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufReader") + .field("path", &self.path) + .field("size", &self.size) + .field("capacity", &self.capacity) + .finish() + } +} + +enum Buffer { + Empty, + Pending(BoxFuture<'static, std::io::Result>), + Ready(Bytes), +} + +impl BufReader { + /// Create a new [`BufReader`] from the provided [`ObjectMeta`] and [`ObjectStore`] + pub fn new(store: Arc, meta: &ObjectMeta) -> Self { + Self::with_capacity(store, meta, DEFAULT_BUFFER_SIZE) + } + + /// Create a new [`BufReader`] from the provided [`ObjectMeta`], [`ObjectStore`], and `capacity` + pub fn with_capacity(store: Arc, meta: &ObjectMeta, capacity: usize) -> Self { + Self { + path: meta.location.clone(), + size: meta.size as _, + store, + capacity, + cursor: 0, + buffer: Buffer::Empty, + } + } + + fn poll_fill_buf_impl( + &mut self, + cx: &mut Context<'_>, + amnt: usize, + ) -> Poll> { + let buf = &mut self.buffer; + loop { + match buf { + Buffer::Empty => { + let store = Arc::clone(&self.store); + let path = self.path.clone(); + let start = self.cursor.min(self.size) as _; + let end = self.cursor.saturating_add(amnt as u64).min(self.size) as _; + + if start == end { + return Poll::Ready(Ok(&[])); + } + + *buf = Buffer::Pending(Box::pin(async move { + Ok(store.get_range(&path, start..end).await?) + })) + } + Buffer::Pending(fut) => match ready!(fut.poll_unpin(cx)) { + Ok(b) => *buf = Buffer::Ready(b), + Err(e) => return Poll::Ready(Err(e)), + }, + Buffer::Ready(r) => return Poll::Ready(Ok(r)), + } + } + } +} + +impl AsyncSeek for BufReader { + fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> { + self.cursor = match position { + SeekFrom::Start(offset) => offset, + SeekFrom::End(offset) => checked_add_signed(self.size, offset).ok_or_else(|| { + Error::new( + ErrorKind::InvalidInput, + format!( + "Seeking {offset} from end of {} byte file would result in overflow", + self.size + ), + ) + })?, + SeekFrom::Current(offset) => { + checked_add_signed(self.cursor, offset).ok_or_else(|| { + Error::new( + ErrorKind::InvalidInput, + format!( + "Seeking {offset} from current offset of {} would result in overflow", + self.cursor + ), + ) + })? 
+ } + }; + self.buffer = Buffer::Empty; + Ok(()) + } + + fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(self.cursor)) + } +} + +impl AsyncRead for BufReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + out: &mut ReadBuf<'_>, + ) -> Poll> { + // Read the maximum of the internal buffer and `out` + let to_read = out.remaining().max(self.capacity); + let r = match ready!(self.poll_fill_buf_impl(cx, to_read)) { + Ok(buf) => { + let to_consume = out.remaining().min(buf.len()); + out.put_slice(&buf[..to_consume]); + self.consume(to_consume); + Ok(()) + } + Err(e) => Err(e), + }; + Poll::Ready(r) + } +} + +impl AsyncBufRead for BufReader { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let capacity = self.capacity; + self.get_mut().poll_fill_buf_impl(cx, capacity) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + match &mut self.buffer { + Buffer::Empty => assert_eq!(amt, 0, "cannot consume from empty buffer"), + Buffer::Ready(b) => match b.len().cmp(&amt) { + Ordering::Less => panic!("{amt} exceeds buffer sized of {}", b.len()), + Ordering::Greater => *b = b.slice(amt..), + Ordering::Equal => self.buffer = Buffer::Empty, + }, + Buffer::Pending(_) => panic!("cannot consume from pending buffer"), + } + self.cursor += amt as u64; + } +} + +/// An async buffered writer compatible with the tokio IO traits +/// +/// This writer adaptively uses [`ObjectStore::put`] or +/// [`ObjectStore::put_multipart`] depending on the amount of data that has +/// been written. +/// +/// Up to `capacity` bytes will be buffered in memory, and flushed on shutdown +/// using [`ObjectStore::put`]. If `capacity` is exceeded, data will instead be +/// streamed using [`ObjectStore::put_multipart`] +pub struct BufWriter { + capacity: usize, + max_concurrency: usize, + attributes: Option, + tags: Option, + extensions: Option<::http::Extensions>, + state: BufWriterState, + store: Arc, +} + +impl std::fmt::Debug for BufWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufWriter") + .field("capacity", &self.capacity) + .finish() + } +} + +enum BufWriterState { + /// Buffer up to capacity bytes + Buffer(Path, PutPayloadMut), + /// [`ObjectStore::put_multipart`] + Prepare(BoxFuture<'static, crate::Result>), + /// Write to a multipart upload + Write(Option), + /// [`ObjectStore::put`] + Flush(BoxFuture<'static, crate::Result<()>>), +} + +impl BufWriter { + /// Create a new [`BufWriter`] from the provided [`ObjectStore`] and [`Path`] + pub fn new(store: Arc, path: Path) -> Self { + Self::with_capacity(store, path, 10 * 1024 * 1024) + } + + /// Create a new [`BufWriter`] from the provided [`ObjectStore`], [`Path`] and `capacity` + pub fn with_capacity(store: Arc, path: Path, capacity: usize) -> Self { + Self { + capacity, + store, + max_concurrency: 8, + attributes: None, + tags: None, + extensions: None, + state: BufWriterState::Buffer(path, PutPayloadMut::new()), + } + } + + /// Override the maximum number of in-flight requests for this writer + /// + /// Defaults to 8 + pub fn with_max_concurrency(self, max_concurrency: usize) -> Self { + Self { + max_concurrency, + ..self + } + } + + /// Set the attributes of the uploaded object + pub fn with_attributes(self, attributes: Attributes) -> Self { + Self { + attributes: Some(attributes), + ..self + } + } + + /// Set the tags of the uploaded object + pub fn with_tags(self, tags: TagSet) -> Self { + Self { + tags: Some(tags), + 
..self + } + } + + /// Set the extensions of the uploaded object + /// + /// Implementation-specific extensions. Intended for use by [`ObjectStore`] implementations + /// that need to pass context-specific information (like tracing spans) via trait methods. + /// + /// These extensions are ignored entirely by backends offered through this crate. + pub fn with_extensions(self, extensions: ::http::Extensions) -> Self { + Self { + extensions: Some(extensions), + ..self + } + } + + /// Write data to the writer in [`Bytes`]. + /// + /// Unlike [`AsyncWrite::poll_write`], `put` can write data without extra copying. + /// + /// This API is recommended while the data source generates [`Bytes`]. + pub async fn put(&mut self, bytes: Bytes) -> crate::Result<()> { + loop { + return match &mut self.state { + BufWriterState::Write(Some(write)) => { + write.wait_for_capacity(self.max_concurrency).await?; + write.put(bytes); + Ok(()) + } + BufWriterState::Write(None) | BufWriterState::Flush(_) => { + panic!("Already shut down") + } + // NOTE + // + // This case should never happen in practice, but rust async API does + // make it possible for users to call `put` before `poll_write` returns `Ready`. + // + // We allow such usage by `await` the future and continue the loop. + BufWriterState::Prepare(f) => { + self.state = BufWriterState::Write(f.await?.into()); + continue; + } + BufWriterState::Buffer(path, b) => { + if b.content_length().saturating_add(bytes.len()) < self.capacity { + b.push(bytes); + Ok(()) + } else { + let buffer = std::mem::take(b); + let path = std::mem::take(path); + let opts = PutMultipartOpts { + attributes: self.attributes.take().unwrap_or_default(), + tags: self.tags.take().unwrap_or_default(), + extensions: self.extensions.take().unwrap_or_default(), + }; + let upload = self.store.put_multipart_opts(&path, opts).await?; + let mut chunked = + WriteMultipart::new_with_chunk_size(upload, self.capacity); + for chunk in buffer.freeze() { + chunked.put(chunk); + } + chunked.put(bytes); + self.state = BufWriterState::Write(Some(chunked)); + Ok(()) + } + } + }; + } + } + + /// Abort this writer, cleaning up any partially uploaded state + /// + /// # Panic + /// + /// Panics if this writer has already been shutdown or aborted + pub async fn abort(&mut self) -> crate::Result<()> { + match &mut self.state { + BufWriterState::Buffer(_, _) | BufWriterState::Prepare(_) => Ok(()), + BufWriterState::Flush(_) => panic!("Already shut down"), + BufWriterState::Write(x) => x.take().unwrap().abort().await, + } + } +} + +impl AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let cap = self.capacity; + let max_concurrency = self.max_concurrency; + loop { + return match &mut self.state { + BufWriterState::Write(Some(write)) => { + ready!(write.poll_for_capacity(cx, max_concurrency))?; + write.write(buf); + Poll::Ready(Ok(buf.len())) + } + BufWriterState::Write(None) | BufWriterState::Flush(_) => { + panic!("Already shut down") + } + BufWriterState::Prepare(f) => { + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); + continue; + } + BufWriterState::Buffer(path, b) => { + if b.content_length().saturating_add(buf.len()) >= cap { + let buffer = std::mem::take(b); + let path = std::mem::take(path); + let opts = PutMultipartOpts { + attributes: self.attributes.take().unwrap_or_default(), + tags: self.tags.take().unwrap_or_default(), + extensions: self.extensions.take().unwrap_or_default(), + }; + let store = 
Arc::clone(&self.store); + self.state = BufWriterState::Prepare(Box::pin(async move { + let upload = store.put_multipart_opts(&path, opts).await?; + let mut chunked = WriteMultipart::new_with_chunk_size(upload, cap); + for chunk in buffer.freeze() { + chunked.put(chunk); + } + Ok(chunked) + })); + continue; + } + b.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + }; + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + return match &mut self.state { + BufWriterState::Write(_) | BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())), + BufWriterState::Flush(_) => panic!("Already shut down"), + BufWriterState::Prepare(f) => { + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); + continue; + } + }; + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match &mut self.state { + BufWriterState::Prepare(f) => { + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); + } + BufWriterState::Buffer(p, b) => { + let buf = std::mem::take(b); + let path = std::mem::take(p); + let opts = PutOptions { + attributes: self.attributes.take().unwrap_or_default(), + tags: self.tags.take().unwrap_or_default(), + ..Default::default() + }; + let store = Arc::clone(&self.store); + self.state = BufWriterState::Flush(Box::pin(async move { + store.put_opts(&path, buf.into(), opts).await?; + Ok(()) + })); + } + BufWriterState::Flush(f) => return f.poll_unpin(cx).map_err(std::io::Error::from), + BufWriterState::Write(x) => { + let upload = x.take().ok_or_else(|| { + std::io::Error::new( + ErrorKind::InvalidInput, + "Cannot shutdown a writer that has already been shut down", + ) + })?; + self.state = BufWriterState::Flush( + async move { + upload.finish().await?; + Ok(()) + } + .boxed(), + ) + } + } + } + } +} + +/// Port of standardised function as requires Rust 1.66 +/// +/// +#[inline] +fn checked_add_signed(a: u64, rhs: i64) -> Option { + let (res, overflowed) = a.overflowing_add(rhs as _); + let overflow = overflowed ^ (rhs < 0); + (!overflow).then_some(res) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::InMemory; + use crate::path::Path; + use crate::{Attribute, GetOptions}; + use itertools::Itertools; + use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; + + #[tokio::test] + async fn test_buf_reader() { + let store = Arc::new(InMemory::new()) as Arc; + + let existent = Path::from("exists.txt"); + const BYTES: usize = 4096; + + let data: Bytes = b"12345678".iter().cycle().copied().take(BYTES).collect(); + store.put(&existent, data.clone().into()).await.unwrap(); + + let meta = store.head(&existent).await.unwrap(); + + let mut reader = BufReader::new(Arc::clone(&store), &meta); + let mut out = Vec::with_capacity(BYTES); + let read = reader.read_to_end(&mut out).await.unwrap(); + + assert_eq!(read, BYTES); + assert_eq!(&out, &data); + + let err = reader.seek(SeekFrom::Current(i64::MIN)).await.unwrap_err(); + assert_eq!( + err.to_string(), + "Seeking -9223372036854775808 from current offset of 4096 would result in overflow" + ); + + reader.rewind().await.unwrap(); + + let err = reader.seek(SeekFrom::Current(-1)).await.unwrap_err(); + assert_eq!( + err.to_string(), + "Seeking -1 from current offset of 0 would result in overflow" + ); + + // Seeking beyond the bounds of the file is permitted but should return no data + reader.seek(SeekFrom::Start(u64::MAX)).await.unwrap(); + let buf = reader.fill_buf().await.unwrap(); + assert!(buf.is_empty()); 
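As an aside to the assertions above, the pattern a consumer would typically follow with `BufReader` looks roughly like the following sketch. It is illustrative only; it assumes the public re-export `object_store::buffered::BufReader` and converts store errors via the crate's `From<Error> for std::io::Error`.

```rust
use std::io::SeekFrom;
use std::sync::Arc;

use object_store::buffered::BufReader;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::{AsyncReadExt, AsyncSeekExt};

async fn read_tail(store: Arc<dyn ObjectStore>) -> std::io::Result<Vec<u8>> {
    let path = Path::from("exists.txt");
    // head() supplies the ObjectMeta (size, location) the reader needs to bound requests
    let meta = store.head(&path).await.map_err(std::io::Error::from)?;
    let mut reader = BufReader::new(store, &meta);

    // Seek relative to the end of the object, then read the remaining bytes
    reader.seek(SeekFrom::End(-16)).await?;
    let mut out = Vec::new();
    reader.read_to_end(&mut out).await?;
    Ok(out)
}
```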
+ + let err = reader.seek(SeekFrom::Current(1)).await.unwrap_err(); + assert_eq!( + err.to_string(), + "Seeking 1 from current offset of 18446744073709551615 would result in overflow" + ); + + for capacity in [200, 1024, 4096, DEFAULT_BUFFER_SIZE] { + let store = Arc::clone(&store); + let mut reader = BufReader::with_capacity(store, &meta, capacity); + + let mut bytes_read = 0; + loop { + let buf = reader.fill_buf().await.unwrap(); + if buf.is_empty() { + assert_eq!(bytes_read, BYTES); + break; + } + assert!(buf.starts_with(b"12345678")); + bytes_read += 8; + reader.consume(8); + } + + let mut buf = Vec::with_capacity(76); + reader.seek(SeekFrom::Current(-76)).await.unwrap(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(&buf, &data[BYTES - 76..]); + + reader.rewind().await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert_eq!(buffer, &data[..capacity.min(BYTES)]); + + reader.seek(SeekFrom::Start(325)).await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert_eq!(buffer, &data[325..(325 + capacity).min(BYTES)]); + + reader.seek(SeekFrom::End(0)).await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert!(buffer.is_empty()); + } + } + + // Note: `BufWriter::with_tags` functionality is tested in `crate::tests::tagging` + #[tokio::test] + async fn test_buf_writer() { + let store = Arc::new(InMemory::new()) as Arc; + let path = Path::from("file.txt"); + let attributes = Attributes::from_iter([ + (Attribute::ContentType, "text/html"), + (Attribute::CacheControl, "max-age=604800"), + ]); + + // Test put + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30) + .with_attributes(attributes.clone()); + writer.write_all(&[0; 20]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[0; 5]).await.unwrap(); + writer.shutdown().await.unwrap(); + let response = store + .get_opts( + &path, + GetOptions { + head: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(response.meta.size, 25); + assert_eq!(response.attributes, attributes); + + // Test multipart + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30) + .with_attributes(attributes.clone()); + writer.write_all(&[0; 20]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[0; 20]).await.unwrap(); + writer.shutdown().await.unwrap(); + let response = store + .get_opts( + &path, + GetOptions { + head: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(response.meta.size, 40); + assert_eq!(response.attributes, attributes); + } + + #[tokio::test] + async fn test_buf_writer_with_put() { + let store = Arc::new(InMemory::new()) as Arc; + let path = Path::from("file.txt"); + + // Test put + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30); + writer + .put(Bytes::from((0..20).collect_vec())) + .await + .unwrap(); + writer + .put(Bytes::from((20..25).collect_vec())) + .await + .unwrap(); + writer.shutdown().await.unwrap(); + let response = store + .get_opts( + &path, + GetOptions { + head: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(response.meta.size, 25); + assert_eq!(response.bytes().await.unwrap(), (0..25).collect_vec()); + + // Test multipart + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30); + writer + .put(Bytes::from((0..20).collect_vec())) + .await + .unwrap(); + writer + .put(Bytes::from((20..40).collect_vec())) + .await + .unwrap(); + writer.shutdown().await.unwrap(); + 
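For comparison with the put-based test above, the tokio `AsyncWrite` path looks like this minimal sketch. Writes smaller than `capacity` stay buffered and are flushed with a single `put` on shutdown; exceeding the capacity switches to a multipart upload. Illustrative only, assuming the re-exports `object_store::buffered::BufWriter` and `object_store::memory::InMemory`.

```rust
use std::sync::Arc;

use object_store::buffered::BufWriter;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::AsyncWriteExt;

async fn write_example() -> std::io::Result<()> {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());

    // With an 8 MiB capacity, this small payload stays in memory and is
    // uploaded with a single put() when the writer is shut down
    let mut writer = BufWriter::with_capacity(store, Path::from("data.bin"), 8 * 1024 * 1024);
    writer.write_all(b"hello world").await?;
    writer.shutdown().await?;
    Ok(())
}
```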
let response = store + .get_opts( + &path, + GetOptions { + head: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(response.meta.size, 40); + assert_eq!(response.bytes().await.unwrap(), (0..40).collect_vec()); + } +} diff --git a/src/chunked.rs b/src/chunked.rs new file mode 100644 index 0000000..2bb30b9 --- /dev/null +++ b/src/chunked.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A [`ChunkedStore`] that can be used to test streaming behaviour + +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Range; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::stream::BoxStream; +use futures::StreamExt; + +use crate::path::Path; +use crate::{ + GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOpts, PutOptions, PutResult, +}; +use crate::{PutPayload, Result}; + +/// Wraps a [`ObjectStore`] and makes its get response return chunks +/// in a controllable manner. +/// +/// A `ChunkedStore` makes the memory consumption and performance of +/// the wrapped [`ObjectStore`] worse. It is intended for use within +/// tests, to control the chunks in the produced output streams. For +/// example, it is used to verify the delimiting logic in +/// newline_delimited_stream. 
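A minimal sketch of using the wrapper follows; it is illustrative, assumes the module is reachable as `object_store::chunked`, and mirrors the unit tests further down in this file.

```rust
use std::sync::Arc;

use futures::StreamExt;
use object_store::chunked::ChunkedStore;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{GetResultPayload, ObjectStore};

async fn stream_in_chunks() -> object_store::Result<()> {
    let inner: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    inner.put(&Path::from("blob"), vec![0u8; 100].into()).await?;

    // Re-chunk every get() response into frames of at most 7 bytes
    let store = ChunkedStore::new(inner, 7);
    if let GetResultPayload::Stream(mut s) = store.get(&Path::from("blob")).await?.payload {
        while let Some(chunk) = s.next().await {
            println!("got {} bytes", chunk?.len());
        }
    }
    Ok(())
}
```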
+#[derive(Debug)] +pub struct ChunkedStore { + inner: Arc, + chunk_size: usize, // chunks are in memory, so we use usize not u64 +} + +impl ChunkedStore { + /// Creates a new [`ChunkedStore`] with the specified chunk_size + pub fn new(inner: Arc, chunk_size: usize) -> Self { + Self { inner, chunk_size } + } +} + +impl Display for ChunkedStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ChunkedStore({})", self.inner) + } +} + +#[async_trait] +impl ObjectStore for ChunkedStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart(&self, location: &Path) -> Result> { + self.inner.put_multipart(location).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let r = self.inner.get_opts(location, options).await?; + let stream = match r.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + GetResultPayload::File(file, path) => { + crate::local::chunked_stream(file, path, r.range.clone(), self.chunk_size) + } + GetResultPayload::Stream(stream) => { + let buffer = BytesMut::new(); + futures::stream::unfold( + (stream, buffer, false, self.chunk_size), + |(mut stream, mut buffer, mut exhausted, chunk_size)| async move { + // Keep accumulating bytes until we reach capacity as long as + // the stream can provide them: + if exhausted { + return None; + } + while buffer.len() < chunk_size { + match stream.next().await { + None => { + exhausted = true; + let slice = buffer.split_off(0).freeze(); + return Some(( + Ok(slice), + (stream, buffer, exhausted, chunk_size), + )); + } + Some(Ok(bytes)) => { + buffer.put(bytes); + } + Some(Err(e)) => { + return Some(( + Err(crate::Error::Generic { + store: "ChunkedStore", + source: Box::new(e), + }), + (stream, buffer, exhausted, chunk_size), + )) + } + }; + } + // Return the chunked values as the next value in the stream + let slice = buffer.split_to(chunk_size).freeze(); + Some((Ok(slice), (stream, buffer, exhausted, chunk_size))) + }, + ) + .boxed() + } + }; + Ok(GetResult { + payload: GetResultPayload::Stream(stream), + ..r + }) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + self.inner.get_range(location, range).await + } + + async fn head(&self, location: &Path) -> Result { + self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.inner.delete(location).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.inner.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + self.inner.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.inner.copy(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.inner.copy_if_not_exists(from, to).await + } +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + + #[cfg(feature = "fs")] + use crate::integration::*; + #[cfg(feature = "fs")] + use crate::local::LocalFileSystem; + use crate::memory::InMemory; + use 
crate::path::Path; + + use super::*; + + #[tokio::test] + async fn test_chunked_basic() { + let location = Path::parse("test").unwrap(); + let store: Arc = Arc::new(InMemory::new()); + store.put(&location, vec![0; 1001].into()).await.unwrap(); + + for chunk_size in [10, 20, 31] { + let store = ChunkedStore::new(Arc::clone(&store), chunk_size); + let mut s = match store.get(&location).await.unwrap().payload { + GetResultPayload::Stream(s) => s, + _ => unreachable!(), + }; + + let mut remaining = 1001; + while let Some(next) = s.next().await { + let size = next.unwrap().len() as u64; + let expected = remaining.min(chunk_size as u64); + assert_eq!(size, expected); + remaining -= expected; + } + assert_eq!(remaining, 0); + } + } + + #[cfg(feature = "fs")] + #[tokio::test] + async fn test_chunked() { + let temporary = tempfile::tempdir().unwrap(); + let integrations: &[Arc] = &[ + Arc::new(InMemory::new()), + Arc::new(LocalFileSystem::new_with_prefix(temporary.path()).unwrap()), + ]; + + for integration in integrations { + let integration = ChunkedStore::new(Arc::clone(integration), 100); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + } +} diff --git a/src/client/backoff.rs b/src/client/backoff.rs new file mode 100644 index 0000000..8382a2e --- /dev/null +++ b/src/client/backoff.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use rand::prelude::*; +use std::time::Duration; + +/// Exponential backoff with decorrelated jitter algorithm +/// +/// The first backoff will always be `init_backoff`. 
+/// +/// Subsequent backoffs will pick a random value between `init_backoff` and +/// `base * previous` where `previous` is the duration of the previous backoff +/// +/// See +#[allow(missing_copy_implementations)] +#[derive(Debug, Clone)] +pub struct BackoffConfig { + /// The initial backoff duration + pub init_backoff: Duration, + /// The maximum backoff duration + pub max_backoff: Duration, + /// The multiplier to use for the next backoff duration + pub base: f64, +} + +impl Default for BackoffConfig { + fn default() -> Self { + Self { + init_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(15), + base: 2., + } + } +} + +/// [`Backoff`] can be created from a [`BackoffConfig`] +/// +/// Consecutive calls to [`Backoff::next`] will return the next backoff interval +/// +pub(crate) struct Backoff { + init_backoff: f64, + next_backoff_secs: f64, + max_backoff_secs: f64, + base: f64, + rng: Option>, +} + +impl std::fmt::Debug for Backoff { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Backoff") + .field("init_backoff", &self.init_backoff) + .field("next_backoff_secs", &self.next_backoff_secs) + .field("max_backoff_secs", &self.max_backoff_secs) + .field("base", &self.base) + .finish() + } +} + +impl Backoff { + /// Create a new [`Backoff`] from the provided [`BackoffConfig`] + pub(crate) fn new(config: &BackoffConfig) -> Self { + Self::new_with_rng(config, None) + } + + /// Creates a new `Backoff` with the optional `rng` + /// + /// Used [`rand::thread_rng()`] if no rng provided + pub(crate) fn new_with_rng( + config: &BackoffConfig, + rng: Option>, + ) -> Self { + let init_backoff = config.init_backoff.as_secs_f64(); + Self { + init_backoff, + next_backoff_secs: init_backoff, + max_backoff_secs: config.max_backoff.as_secs_f64(), + base: config.base, + rng, + } + } + + /// Returns the next backoff duration to wait for + pub(crate) fn next(&mut self) -> Duration { + let range = self.init_backoff..(self.next_backoff_secs * self.base); + + let rand_backoff = match self.rng.as_mut() { + Some(rng) => rng.gen_range(range), + None => thread_rng().gen_range(range), + }; + + let next_backoff = self.max_backoff_secs.min(rand_backoff); + Duration::from_secs_f64(std::mem::replace(&mut self.next_backoff_secs, next_backoff)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::mock::StepRng; + + #[test] + fn test_backoff() { + let init_backoff_secs = 1.; + let max_backoff_secs = 500.; + let base = 3.; + + let config = BackoffConfig { + init_backoff: Duration::from_secs_f64(init_backoff_secs), + max_backoff: Duration::from_secs_f64(max_backoff_secs), + base, + }; + + let assert_fuzzy_eq = |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{a} != {b}"); + + // Create a static rng that takes the minimum of the range + let rng = Box::new(StepRng::new(0, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for _ in 0..20 { + assert_eq!(backoff.next().as_secs_f64(), init_backoff_secs); + } + + // Create a static rng that takes the maximum of the range + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for i in 0..20 { + let value = (base.powi(i) * init_backoff_secs).min(max_backoff_secs); + assert_fuzzy_eq(backoff.next().as_secs_f64(), value); + } + + // Create a static rng that takes the mid point of the range + let rng = Box::new(StepRng::new(u64::MAX / 2, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + let mut value = 
init_backoff_secs; + for _ in 0..20 { + assert_fuzzy_eq(backoff.next().as_secs_f64(), value); + value = + (init_backoff_secs + (value * base - init_backoff_secs) / 2.).min(max_backoff_secs); + } + } +} diff --git a/src/client/body.rs b/src/client/body.rs new file mode 100644 index 0000000..8f62afa --- /dev/null +++ b/src/client/body.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::connection::{HttpError, HttpErrorKind}; +use crate::{collect_bytes, PutPayload}; +use bytes::Bytes; +use futures::stream::BoxStream; +use futures::StreamExt; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use hyper::body::{Body, Frame, SizeHint}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// An HTTP Request +pub type HttpRequest = http::Request; + +/// The [`Body`] of an [`HttpRequest`] +#[derive(Debug, Clone)] +pub struct HttpRequestBody(Inner); + +impl HttpRequestBody { + /// An empty [`HttpRequestBody`] + pub fn empty() -> Self { + Self(Inner::Bytes(Bytes::new())) + } + + #[cfg(not(target_arch = "wasm32"))] + pub(crate) fn into_reqwest(self) -> reqwest::Body { + match self.0 { + Inner::Bytes(b) => b.into(), + Inner::PutPayload(_, payload) => reqwest::Body::wrap_stream(futures::stream::iter( + payload.into_iter().map(Ok::<_, HttpError>), + )), + } + } + + /// Returns true if this body is empty + pub fn is_empty(&self) -> bool { + match &self.0 { + Inner::Bytes(x) => x.is_empty(), + Inner::PutPayload(_, x) => x.iter().any(|x| !x.is_empty()), + } + } + + /// Returns the total length of the [`Bytes`] in this body + pub fn content_length(&self) -> usize { + match &self.0 { + Inner::Bytes(x) => x.len(), + Inner::PutPayload(_, x) => x.content_length(), + } + } + + /// If this body consists of a single contiguous [`Bytes`], returns it + pub fn as_bytes(&self) -> Option<&Bytes> { + match &self.0 { + Inner::Bytes(x) => Some(x), + _ => None, + } + } +} + +impl From for HttpRequestBody { + fn from(value: Bytes) -> Self { + Self(Inner::Bytes(value)) + } +} + +impl From> for HttpRequestBody { + fn from(value: Vec) -> Self { + Self(Inner::Bytes(value.into())) + } +} + +impl From for HttpRequestBody { + fn from(value: String) -> Self { + Self(Inner::Bytes(value.into())) + } +} + +impl From for HttpRequestBody { + fn from(value: PutPayload) -> Self { + Self(Inner::PutPayload(0, value)) + } +} + +#[derive(Debug, Clone)] +enum Inner { + Bytes(Bytes), + PutPayload(usize, PutPayload), +} + +impl Body for HttpRequestBody { + type Data = Bytes; + type Error = HttpError; + + fn poll_frame( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Poll::Ready(match &mut self.0 { + Inner::Bytes(bytes) => { + let out = bytes.split_off(0); + if out.is_empty() { + None + } 
else { + Some(Ok(Frame::data(out))) + } + } + Inner::PutPayload(offset, payload) => { + let slice = payload.as_ref(); + if *offset == slice.len() { + None + } else { + Some(Ok(Frame::data( + slice[std::mem::replace(offset, *offset + 1)].clone(), + ))) + } + } + }) + } + + fn is_end_stream(&self) -> bool { + match self.0 { + Inner::Bytes(ref bytes) => bytes.is_empty(), + Inner::PutPayload(offset, ref body) => offset == body.as_ref().len(), + } + } + + fn size_hint(&self) -> SizeHint { + match self.0 { + Inner::Bytes(ref bytes) => SizeHint::with_exact(bytes.len() as u64), + Inner::PutPayload(offset, ref payload) => { + let iter = payload.as_ref().iter().skip(offset); + SizeHint::with_exact(iter.map(|x| x.len() as u64).sum()) + } + } + } +} + +/// An HTTP response +pub type HttpResponse = http::Response; + +/// The body of an [`HttpResponse`] +#[derive(Debug)] +pub struct HttpResponseBody(BoxBody); + +impl HttpResponseBody { + /// Create an [`HttpResponseBody`] from the provided [`Body`] + /// + /// Note: [`BodyExt::map_err`] can be used to alter error variants + pub fn new(body: B) -> Self + where + B: Body + Send + Sync + 'static, + { + Self(BoxBody::new(body)) + } + + /// Collects this response into a [`Bytes`] + pub async fn bytes(self) -> Result { + let size_hint = self.0.size_hint().lower(); + let s = self.0.into_data_stream(); + collect_bytes(s, Some(size_hint)).await + } + + /// Returns a stream of this response data + pub fn bytes_stream(self) -> BoxStream<'static, Result> { + self.0.into_data_stream().boxed() + } + + /// Returns the response as a [`String`] + pub(crate) async fn text(self) -> Result { + let b = self.bytes().await?; + String::from_utf8(b.into()).map_err(|e| HttpError::new(HttpErrorKind::Decode, e)) + } + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) async fn json(self) -> Result { + let b = self.bytes().await?; + serde_json::from_slice(&b).map_err(|e| HttpError::new(HttpErrorKind::Decode, e)) + } +} + +impl From for HttpResponseBody { + fn from(value: Bytes) -> Self { + Self::new(Full::new(value).map_err(|e| match e {})) + } +} + +impl From> for HttpResponseBody { + fn from(value: Vec) -> Self { + Bytes::from(value).into() + } +} + +impl From for HttpResponseBody { + fn from(value: String) -> Self { + Bytes::from(value).into() + } +} diff --git a/src/client/builder.rs b/src/client/builder.rs new file mode 100644 index 0000000..fcbc6e8 --- /dev/null +++ b/src/client/builder.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
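Before moving on to the request builder, a quick illustration of the body type from the previous file: `HttpRequestBody` wraps either a contiguous `Bytes` or a (potentially multi-chunk) `PutPayload` without copying. A hedged sketch, assuming these types are re-exported under `object_store::client`:

```rust
use bytes::Bytes;
use object_store::client::HttpRequestBody;
use object_store::PutPayload;

fn body_examples() {
    // A single contiguous buffer can be borrowed back without copying
    let body = HttpRequestBody::from(Bytes::from_static(b"hello"));
    assert_eq!(body.content_length(), 5);
    assert!(body.as_bytes().is_some());

    // A PutPayload-backed body is streamed chunk by chunk, so there is
    // no single contiguous buffer to borrow
    let body = HttpRequestBody::from(PutPayload::from(Bytes::from_static(b"abcd")));
    assert_eq!(body.content_length(), 4);
    assert!(body.as_bytes().is_none());
}
```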
+ +use crate::client::connection::HttpErrorKind; +use crate::client::{HttpClient, HttpError, HttpRequest, HttpRequestBody}; +use http::header::{InvalidHeaderName, InvalidHeaderValue}; +use http::uri::InvalidUri; +use http::{HeaderName, HeaderValue, Method, Uri}; + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RequestBuilderError { + #[error("Invalid URI")] + InvalidUri(#[from] InvalidUri), + + #[error("Invalid Header Value")] + InvalidHeaderValue(#[from] InvalidHeaderValue), + + #[error("Invalid Header Name")] + InvalidHeaderName(#[from] InvalidHeaderName), + + #[error("JSON serialization error")] + SerdeJson(#[from] serde_json::Error), + + #[error("URL serialization error")] + SerdeUrl(#[from] serde_urlencoded::ser::Error), +} + +impl From for HttpError { + fn from(value: RequestBuilderError) -> Self { + Self::new(HttpErrorKind::Request, value) + } +} + +impl From for RequestBuilderError { + fn from(value: std::convert::Infallible) -> Self { + match value {} + } +} + +pub(crate) struct HttpRequestBuilder { + client: HttpClient, + request: Result, +} + +impl HttpRequestBuilder { + pub(crate) fn new(client: HttpClient) -> Self { + Self { + client, + request: Ok(HttpRequest::new(HttpRequestBody::empty())), + } + } + + #[cfg(any(feature = "aws", feature = "azure"))] + pub(crate) fn from_parts(client: HttpClient, request: HttpRequest) -> Self { + Self { + client, + request: Ok(request), + } + } + + pub(crate) fn method(mut self, method: Method) -> Self { + if let Ok(r) = &mut self.request { + *r.method_mut() = method; + } + self + } + + pub(crate) fn uri(mut self, url: U) -> Self + where + U: TryInto, + U::Error: Into, + { + match (url.try_into(), &mut self.request) { + (Ok(uri), Ok(r)) => *r.uri_mut() = uri, + (Err(e), Ok(_)) => self.request = Err(e.into()), + (_, Err(_)) => {} + } + self + } + + pub(crate) fn extensions(mut self, extensions: ::http::Extensions) -> Self { + if let Ok(r) = &mut self.request { + *r.extensions_mut() = extensions; + } + self + } + + pub(crate) fn header(mut self, name: K, value: V) -> Self + where + K: TryInto, + K::Error: Into, + V: TryInto, + V::Error: Into, + { + match (name.try_into(), value.try_into(), &mut self.request) { + (Ok(name), Ok(value), Ok(r)) => { + r.headers_mut().insert(name, value); + } + (Err(e), _, Ok(_)) => self.request = Err(e.into()), + (_, Err(e), Ok(_)) => self.request = Err(e.into()), + (_, _, Err(_)) => {} + } + self + } + + #[cfg(feature = "aws")] + pub(crate) fn headers(mut self, headers: http::HeaderMap) -> Self { + use http::header::{Entry, OccupiedEntry}; + + if let Ok(ref mut req) = self.request { + // IntoIter of HeaderMap yields (Option, HeaderValue). + // The first time a name is yielded, it will be Some(name), and if + // there are more values with the same name, the next yield will be + // None. 
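To make that iteration contract concrete, a small stand-alone snippet using the `http` crate directly (illustrative only):

```rust
use http::{HeaderMap, HeaderValue};

fn header_map_iteration() {
    let mut headers = HeaderMap::new();
    headers.append("x-amz-meta-foo", HeaderValue::from_static("a"));
    headers.append("x-amz-meta-foo", HeaderValue::from_static("b"));

    let pairs: Vec<_> = headers.into_iter().collect();
    // The first value of a name carries Some(name); repeated values yield None
    assert_eq!(pairs[0].0.as_ref().map(|n| n.as_str()), Some("x-amz-meta-foo"));
    assert!(pairs[1].0.is_none());
    assert_eq!(pairs[1].1, "b");
}
```

The `prev_entry` bookkeeping that follows exists precisely to append those `None`-keyed values to the entry inserted on the previous iteration.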
+ + let mut prev_entry: Option> = None; + for (key, value) in headers { + match key { + Some(key) => match req.headers_mut().entry(key) { + Entry::Occupied(mut e) => { + e.insert(value); + prev_entry = Some(e); + } + Entry::Vacant(e) => { + let e = e.insert_entry(value); + prev_entry = Some(e); + } + }, + None => match prev_entry { + Some(ref mut entry) => { + entry.append(value); + } + None => unreachable!("HeaderMap::into_iter yielded None first"), + }, + } + } + } + self + } + + #[cfg(feature = "gcp")] + pub(crate) fn bearer_auth(mut self, token: &str) -> Self { + let value = HeaderValue::try_from(format!("Bearer {}", token)); + match (value, &mut self.request) { + (Ok(mut v), Ok(r)) => { + v.set_sensitive(true); + r.headers_mut().insert(http::header::AUTHORIZATION, v); + } + (Err(e), Ok(_)) => self.request = Err(e.into()), + (_, Err(_)) => {} + } + self + } + + #[cfg(any(feature = "aws", feature = "gcp"))] + pub(crate) fn json(mut self, s: S) -> Self { + match (serde_json::to_vec(&s), &mut self.request) { + (Ok(json), Ok(request)) => { + *request.body_mut() = json.into(); + } + (Err(e), Ok(_)) => self.request = Err(e.into()), + (_, Err(_)) => {} + } + self + } + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn query(mut self, query: &T) -> Self { + let mut error = None; + if let Ok(ref mut req) = self.request { + let mut out = format!("{}?", req.uri().path()); + let mut encoder = form_urlencoded::Serializer::new(&mut out); + let serializer = serde_urlencoded::Serializer::new(&mut encoder); + + if let Err(err) = query.serialize(serializer) { + error = Some(err.into()); + } + + match http::uri::PathAndQuery::from_maybe_shared(out) { + Ok(p) => { + let mut parts = req.uri().clone().into_parts(); + parts.path_and_query = Some(p); + *req.uri_mut() = Uri::from_parts(parts).unwrap(); + } + Err(err) => error = Some(err.into()), + } + } + if let Some(err) = error { + self.request = Err(err); + } + self + } + + #[cfg(any(feature = "gcp", feature = "azure"))] + pub(crate) fn form(mut self, form: T) -> Self { + let mut error = None; + if let Ok(ref mut req) = self.request { + match serde_urlencoded::to_string(form) { + Ok(body) => { + req.headers_mut().insert( + http::header::CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ); + *req.body_mut() = body.into(); + } + Err(err) => error = Some(err.into()), + } + } + if let Some(err) = error { + self.request = Err(err); + } + self + } + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn body(mut self, b: impl Into) -> Self { + if let Ok(r) = &mut self.request { + *r.body_mut() = b.into(); + } + self + } + + pub(crate) fn into_parts(self) -> (HttpClient, Result) { + (self.client, self.request) + } +} + +#[cfg(any(test, feature = "azure"))] +pub(crate) fn add_query_pairs(uri: &mut Uri, query_pairs: I) +where + I: IntoIterator, + I::Item: std::borrow::Borrow<(K, V)>, + K: AsRef, + V: AsRef, +{ + let mut parts = uri.clone().into_parts(); + + let mut out = match parts.path_and_query { + Some(p) => match p.query() { + Some(x) => format!("{}?{}", p.path(), x), + None => format!("{}?", p.path()), + }, + None => "/?".to_string(), + }; + let mut serializer = form_urlencoded::Serializer::new(&mut out); + serializer.extend_pairs(query_pairs); + + parts.path_and_query = Some(out.try_into().unwrap()); + *uri = Uri::from_parts(parts).unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_query_pairs() { + let mut uri = 
Uri::from_static("https://foo@example.com/bananas?foo=1"); + + add_query_pairs(&mut uri, [("bingo", "foo"), ("auth", "test")]); + assert_eq!( + uri.to_string(), + "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test" + ); + + add_query_pairs(&mut uri, [("t1", "funky shenanigans"), ("a", "😀")]); + assert_eq!( + uri.to_string(), + "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test&t1=funky+shenanigans&a=%F0%9F%98%80" + ); + } +} diff --git a/src/client/connection.rs b/src/client/connection.rs new file mode 100644 index 0000000..7e2daf4 --- /dev/null +++ b/src/client/connection.rs @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::body::{HttpRequest, HttpResponse}; +use crate::client::builder::{HttpRequestBuilder, RequestBuilderError}; +use crate::client::HttpResponseBody; +use crate::ClientOptions; +use async_trait::async_trait; +use http::{Method, Uri}; +use http_body_util::BodyExt; +use std::error::Error; +use std::sync::Arc; + +/// An HTTP protocol error +/// +/// Clients should return this when an HTTP request fails to be completed, e.g. because +/// of a connection issue. This does **not** include HTTP requests that are return +/// non 2xx Status Codes, as these should instead be returned as an [`HttpResponse`] +/// with the appropriate status code set. 
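Given that contract, a downstream caller might use the error kind (defined just below) to decide whether a failed request is worth retrying. This is a hedged sketch; the crate's own retry layer is more careful, for example timeouts and interrupts are only retried for idempotent requests. It assumes `HttpError` and `HttpErrorKind` are re-exported under `object_store::client`.

```rust
use object_store::client::{HttpError, HttpErrorKind};

/// Whether a failed request is a candidate for retry, per the kind documentation
fn is_retry_candidate(err: &HttpError) -> bool {
    matches!(
        err.kind(),
        HttpErrorKind::Connect
            | HttpErrorKind::Request
            | HttpErrorKind::Timeout
            | HttpErrorKind::Interrupted
    )
}
```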
+#[derive(Debug, thiserror::Error)] +#[error("HTTP error: {source}")] +pub struct HttpError { + kind: HttpErrorKind, + #[source] + source: Box, +} + +/// Identifies the kind of [`HttpError`] +/// +/// This is used, among other things, to determine if a request can be retried +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum HttpErrorKind { + /// An error occurred whilst connecting to the remote + /// + /// Will be automatically retried + Connect, + /// An error occurred whilst making the request + /// + /// Will be automatically retried + Request, + /// Request timed out + /// + /// Will be automatically retried if the request is idempotent + Timeout, + /// The request was aborted + /// + /// Will be automatically retried if the request is idempotent + Interrupted, + /// An error occurred whilst decoding the response + /// + /// Will not be automatically retried + Decode, + /// An unknown error occurred + /// + /// Will not be automatically retried + Unknown, +} + +impl HttpError { + /// Create a new [`HttpError`] with the optional status code + pub fn new(kind: HttpErrorKind, e: E) -> Self + where + E: Error + Send + Sync + 'static, + { + Self { + kind, + source: Box::new(e), + } + } + + pub(crate) fn reqwest(e: reqwest::Error) -> Self { + #[cfg(not(target_arch = "wasm32"))] + let is_connect = || e.is_connect(); + #[cfg(target_arch = "wasm32")] + let is_connect = || false; + + let mut kind = if e.is_timeout() { + HttpErrorKind::Timeout + } else if is_connect() { + HttpErrorKind::Connect + } else if e.is_decode() { + HttpErrorKind::Decode + } else { + HttpErrorKind::Unknown + }; + + // Reqwest error variants aren't great, attempt to refine them + let mut source = e.source(); + while let Some(e) = source { + if let Some(e) = e.downcast_ref::() { + if e.is_closed() || e.is_incomplete_message() || e.is_body_write_aborted() { + kind = HttpErrorKind::Request; + } else if e.is_timeout() { + kind = HttpErrorKind::Timeout; + } + break; + } + if let Some(e) = e.downcast_ref::() { + match e.kind() { + std::io::ErrorKind::TimedOut => kind = HttpErrorKind::Timeout, + std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::UnexpectedEof => kind = HttpErrorKind::Interrupted, + _ => {} + } + break; + } + source = e.source(); + } + Self { + kind, + // We strip URL as it will be included by RetryError if not sensitive + source: Box::new(e.without_url()), + } + } + + /// Returns the [`HttpErrorKind`] + pub fn kind(&self) -> HttpErrorKind { + self.kind + } +} + +/// An asynchronous function from a [`HttpRequest`] to a [`HttpResponse`]. 
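To show the extension point this trait provides, here is a sketch of a service that injects a header on every request and then delegates to `reqwest`. It is illustrative; it relies on the `HttpService for reqwest::Client` impl further down in this file and assumes the types are re-exported under `object_store::client`.

```rust
use async_trait::async_trait;
use http::header::USER_AGENT;
use http::HeaderValue;
use object_store::client::{HttpError, HttpRequest, HttpResponse, HttpService};

#[derive(Debug)]
struct WithUserAgent(reqwest::Client);

#[async_trait]
impl HttpService for WithUserAgent {
    async fn call(&self, mut req: HttpRequest) -> Result<HttpResponse, HttpError> {
        // Decorate the request, then delegate to the wrapped reqwest client
        req.headers_mut()
            .insert(USER_AGENT, HeaderValue::from_static("my-app/1.0"));
        self.0.call(req).await
    }
}
```

Such a service can then be handed to `HttpClient::new` and plugged into a store builder through a custom `HttpConnector`.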
+#[async_trait] +pub trait HttpService: std::fmt::Debug + Send + Sync + 'static { + /// Perform [`HttpRequest`] returning [`HttpResponse`] + async fn call(&self, req: HttpRequest) -> Result; +} + +/// An HTTP client +#[derive(Debug, Clone)] +pub struct HttpClient(Arc); + +impl HttpClient { + /// Create a new [`HttpClient`] from an [`HttpService`] + pub fn new(service: impl HttpService + 'static) -> Self { + Self(Arc::new(service)) + } + + /// Performs [`HttpRequest`] using this client + pub async fn execute(&self, request: HttpRequest) -> Result { + self.0.call(request).await + } + + #[allow(unused)] + pub(crate) fn get(&self, url: U) -> HttpRequestBuilder + where + U: TryInto, + U::Error: Into, + { + self.request(Method::GET, url) + } + + #[allow(unused)] + pub(crate) fn post(&self, url: U) -> HttpRequestBuilder + where + U: TryInto, + U::Error: Into, + { + self.request(Method::POST, url) + } + + #[allow(unused)] + pub(crate) fn put(&self, url: U) -> HttpRequestBuilder + where + U: TryInto, + U::Error: Into, + { + self.request(Method::PUT, url) + } + + #[allow(unused)] + pub(crate) fn delete(&self, url: U) -> HttpRequestBuilder + where + U: TryInto, + U::Error: Into, + { + self.request(Method::DELETE, url) + } + + pub(crate) fn request(&self, method: Method, url: U) -> HttpRequestBuilder + where + U: TryInto, + U::Error: Into, + { + HttpRequestBuilder::new(self.clone()) + .uri(url) + .method(method) + } +} + +#[async_trait] +#[cfg(not(target_arch = "wasm32"))] +impl HttpService for reqwest::Client { + async fn call(&self, req: HttpRequest) -> Result { + let (parts, body) = req.into_parts(); + + let url = parts.uri.to_string().parse().unwrap(); + let mut req = reqwest::Request::new(parts.method, url); + *req.headers_mut() = parts.headers; + *req.body_mut() = Some(body.into_reqwest()); + + let r = self.execute(req).await.map_err(HttpError::reqwest)?; + let res: http::Response = r.into(); + let (parts, body) = res.into_parts(); + + let body = HttpResponseBody::new(body.map_err(HttpError::reqwest)); + Ok(HttpResponse::from_parts(parts, body)) + } +} + +/// A factory for [`HttpClient`] +pub trait HttpConnector: std::fmt::Debug + Send + Sync + 'static { + /// Create a new [`HttpClient`] with the provided [`ClientOptions`] + fn connect(&self, options: &ClientOptions) -> crate::Result; +} + +/// [`HttpConnector`] using [`reqwest::Client`] +#[derive(Debug, Default)] +#[allow(missing_copy_implementations)] +#[cfg(not(target_arch = "wasm32"))] +pub struct ReqwestConnector {} + +#[cfg(not(target_arch = "wasm32"))] +impl HttpConnector for ReqwestConnector { + fn connect(&self, options: &ClientOptions) -> crate::Result { + let client = options.client()?; + Ok(HttpClient::new(client)) + } +} + +#[cfg(target_arch = "wasm32")] +pub(crate) fn http_connector( + custom: Option>, +) -> crate::Result> { + match custom { + Some(x) => Ok(x), + None => Err(crate::Error::NotSupported { + source: "WASM32 architectures must provide an HTTPConnector" + .to_string() + .into(), + }), + } +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn http_connector( + custom: Option>, +) -> crate::Result> { + match custom { + Some(x) => Ok(x), + None => Ok(Arc::new(ReqwestConnector {})), + } +} diff --git a/src/client/dns.rs b/src/client/dns.rs new file mode 100644 index 0000000..51df926 --- /dev/null +++ b/src/client/dns.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::net::ToSocketAddrs; + +use rand::prelude::SliceRandom; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use tokio::task::JoinSet; + +type DynErr = Box; + +#[derive(Debug)] +pub(crate) struct ShuffleResolver; + +impl Resolve for ShuffleResolver { + fn resolve(&self, name: Name) -> Resolving { + Box::pin(async move { + // use `JoinSet` to propagate cancelation + let mut tasks = JoinSet::new(); + tasks.spawn_blocking(move || { + let it = (name.as_str(), 0).to_socket_addrs()?; + let mut addrs = it.collect::>(); + + addrs.shuffle(&mut rand::thread_rng()); + + Ok(Box::new(addrs.into_iter()) as Addrs) + }); + + tasks + .join_next() + .await + .expect("spawned on task") + .map_err(|err| Box::new(err) as DynErr)? + }) + } +} diff --git a/src/client/get.rs b/src/client/get.rs new file mode 100644 index 0000000..4c65c6d --- /dev/null +++ b/src/client/get.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
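The `ContentRange` parser defined below maps an HTTP `Content-Range` header of the form `bytes <first>-<last>/<size>` (with an inclusive last byte) onto a half-open Rust range. A self-contained sketch of the same mapping, for reference only:

```rust
use std::ops::Range;

// "bytes 2-5/12" covers bytes 2,3,4,5 of a 12-byte object,
// i.e. the half-open range 2..6 and a total size of 12
fn parse_content_range(s: &str) -> Option<(Range<u64>, u64)> {
    let rem = s.trim().strip_prefix("bytes ")?;
    let (range, size) = rem.split_once('/')?;
    let (first, last) = range.split_once('-')?;
    Some((first.parse().ok()?..last.parse::<u64>().ok()? + 1, size.parse().ok()?))
}

fn main() {
    assert_eq!(parse_content_range("bytes 2-5/12"), Some((2..6, 12)));
    assert_eq!(parse_content_range("bytes 2-2/*"), None); // unknown total size is rejected
}
```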
+ +use std::ops::Range; + +use crate::client::header::{header_meta, HeaderConfig}; +use crate::client::HttpResponse; +use crate::path::Path; +use crate::{Attribute, Attributes, GetOptions, GetRange, GetResult, GetResultPayload, Result}; +use async_trait::async_trait; +use futures::{StreamExt, TryStreamExt}; +use http::header::{ + CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_RANGE, + CONTENT_TYPE, +}; +use http::StatusCode; +use reqwest::header::ToStrError; + +/// A client that can perform a get request +#[async_trait] +pub(crate) trait GetClient: Send + Sync + 'static { + const STORE: &'static str; + + /// Configure the [`HeaderConfig`] for this client + const HEADER_CONFIG: HeaderConfig; + + async fn get_request(&self, path: &Path, options: GetOptions) -> Result; +} + +/// Extension trait for [`GetClient`] that adds common retrieval functionality +#[async_trait] +pub(crate) trait GetClientExt { + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result; +} + +#[async_trait] +impl GetClientExt for T { + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let range = options.range.clone(); + if let Some(r) = range.as_ref() { + r.is_valid().map_err(|e| crate::Error::Generic { + store: T::STORE, + source: Box::new(e), + })?; + } + let response = self.get_request(location, options).await?; + get_result::(location, range, response).map_err(|e| crate::Error::Generic { + store: T::STORE, + source: Box::new(e), + }) + } +} + +struct ContentRange { + /// The range of the object returned + range: Range, + /// The total size of the object being requested + size: u64, +} + +impl ContentRange { + /// Parse a content range of the form `bytes -/` + /// + /// + fn from_str(s: &str) -> Option { + let rem = s.trim().strip_prefix("bytes ")?; + let (range, size) = rem.split_once('/')?; + let size = size.parse().ok()?; + + let (start_s, end_s) = range.split_once('-')?; + + let start = start_s.parse().ok()?; + let end: u64 = end_s.parse().ok()?; + + Some(Self { + size, + range: start..end + 1, + }) + } +} + +/// A specialized `Error` for get-related errors +#[derive(Debug, thiserror::Error)] +enum GetResultError { + #[error(transparent)] + Header { + #[from] + source: crate::client::header::Error, + }, + + #[error(transparent)] + InvalidRangeRequest { + #[from] + source: crate::util::InvalidGetRange, + }, + + #[error("Received non-partial response when range requested")] + NotPartial, + + #[error("Content-Range header not present in partial response")] + NoContentRange, + + #[error("Failed to parse value for CONTENT_RANGE header: \"{value}\"")] + ParseContentRange { value: String }, + + #[error("Content-Range header contained non UTF-8 characters")] + InvalidContentRange { source: ToStrError }, + + #[error("Cache-Control header contained non UTF-8 characters")] + InvalidCacheControl { source: ToStrError }, + + #[error("Content-Disposition header contained non UTF-8 characters")] + InvalidContentDisposition { source: ToStrError }, + + #[error("Content-Encoding header contained non UTF-8 characters")] + InvalidContentEncoding { source: ToStrError }, + + #[error("Content-Language header contained non UTF-8 characters")] + InvalidContentLanguage { source: ToStrError }, + + #[error("Content-Type header contained non UTF-8 characters")] + InvalidContentType { source: ToStrError }, + + #[error("Metadata value for \"{key:?}\" contained non UTF-8 characters")] + InvalidMetadata { key: String }, + + #[error("Requested {expected:?}, got 
{actual:?}")] + UnexpectedRange { + expected: Range, + actual: Range, + }, +} + +fn get_result( + location: &Path, + range: Option, + response: HttpResponse, +) -> Result { + let mut meta = header_meta(location, response.headers(), T::HEADER_CONFIG)?; + + // ensure that we receive the range we asked for + let range = if let Some(expected) = range { + if response.status() != StatusCode::PARTIAL_CONTENT { + return Err(GetResultError::NotPartial); + } + + let val = response + .headers() + .get(CONTENT_RANGE) + .ok_or(GetResultError::NoContentRange)?; + + let value = val + .to_str() + .map_err(|source| GetResultError::InvalidContentRange { source })?; + + let value = ContentRange::from_str(value).ok_or_else(|| { + let value = value.into(); + GetResultError::ParseContentRange { value } + })?; + + let actual = value.range; + + // Update size to reflect full size of object (#5272) + meta.size = value.size; + + let expected = expected.as_range(meta.size)?; + + if actual != expected { + return Err(GetResultError::UnexpectedRange { expected, actual }); + } + + actual + } else { + 0..meta.size + }; + + macro_rules! parse_attributes { + ($headers:expr, $(($header:expr, $attr:expr, $map_err:expr)),*) => {{ + let mut attributes = Attributes::new(); + $( + if let Some(x) = $headers.get($header) { + let x = x.to_str().map_err($map_err)?; + attributes.insert($attr, x.to_string().into()); + } + )* + attributes + }} + } + + let mut attributes = parse_attributes!( + response.headers(), + (CACHE_CONTROL, Attribute::CacheControl, |source| { + GetResultError::InvalidCacheControl { source } + }), + ( + CONTENT_DISPOSITION, + Attribute::ContentDisposition, + |source| GetResultError::InvalidContentDisposition { source } + ), + (CONTENT_ENCODING, Attribute::ContentEncoding, |source| { + GetResultError::InvalidContentEncoding { source } + }), + (CONTENT_LANGUAGE, Attribute::ContentLanguage, |source| { + GetResultError::InvalidContentLanguage { source } + }), + (CONTENT_TYPE, Attribute::ContentType, |source| { + GetResultError::InvalidContentType { source } + }) + ); + + // Add attributes that match the user-defined metadata prefix (e.g. 
x-amz-meta-) + if let Some(prefix) = T::HEADER_CONFIG.user_defined_metadata_prefix { + for (key, val) in response.headers() { + if let Some(suffix) = key.as_str().strip_prefix(prefix) { + if let Ok(val_str) = val.to_str() { + attributes.insert( + Attribute::Metadata(suffix.to_string().into()), + val_str.to_string().into(), + ); + } else { + return Err(GetResultError::InvalidMetadata { + key: key.to_string(), + }); + } + } + } + } + + let stream = response + .into_body() + .bytes_stream() + .map_err(|source| crate::Error::Generic { + store: T::STORE, + source: Box::new(source), + }) + .boxed(); + + Ok(GetResult { + range, + meta, + attributes, + payload: GetResultPayload::Stream(stream), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use http::header::*; + + struct TestClient {} + + #[async_trait] + impl GetClient for TestClient { + const STORE: &'static str = "TEST"; + + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: false, + last_modified_required: false, + version_header: None, + user_defined_metadata_prefix: Some("x-test-meta-"), + }; + + async fn get_request(&self, _: &Path, _: GetOptions) -> Result { + unimplemented!() + } + } + + fn make_response( + object_size: usize, + range: Option>, + status: StatusCode, + content_range: Option<&str>, + headers: Option>, + ) -> HttpResponse { + let mut builder = http::Response::builder(); + if let Some(range) = content_range { + builder = builder.header(CONTENT_RANGE, range); + } + + let body = match range { + Some(range) => vec![0_u8; range.end - range.start], + None => vec![0_u8; object_size], + }; + + if let Some(headers) = headers { + for (key, value) in headers { + builder = builder.header(key, value); + } + } + + builder + .status(status) + .header(CONTENT_LENGTH, object_size) + .body(body.into()) + .unwrap() + } + + #[tokio::test] + async fn test_get_result() { + let path = Path::from("test"); + + let resp = make_response(12, None, StatusCode::OK, None, None); + let res = get_result::(&path, None, resp).unwrap(); + assert_eq!(res.meta.size, 12); + assert_eq!(res.range, 0..12); + let bytes = res.bytes().await.unwrap(); + assert_eq!(bytes.len(), 12); + + let get_range = GetRange::from(2..3); + + let resp = make_response( + 12, + Some(2..3), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-2/12"), + None, + ); + let res = get_result::(&path, Some(get_range.clone()), resp).unwrap(); + assert_eq!(res.meta.size, 12); + assert_eq!(res.range, 2..3); + let bytes = res.bytes().await.unwrap(); + assert_eq!(bytes.len(), 1); + + let resp = make_response(12, Some(2..3), StatusCode::OK, None, None); + let err = get_result::(&path, Some(get_range.clone()), resp).unwrap_err(); + assert_eq!( + err.to_string(), + "Received non-partial response when range requested" + ); + + let resp = make_response( + 12, + Some(2..3), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-3/12"), + None, + ); + let err = get_result::(&path, Some(get_range.clone()), resp).unwrap_err(); + assert_eq!(err.to_string(), "Requested 2..3, got 2..4"); + + let resp = make_response( + 12, + Some(2..3), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-2/*"), + None, + ); + let err = get_result::(&path, Some(get_range.clone()), resp).unwrap_err(); + assert_eq!( + err.to_string(), + "Failed to parse value for CONTENT_RANGE header: \"bytes 2-2/*\"" + ); + + let resp = make_response(12, Some(2..3), StatusCode::PARTIAL_CONTENT, None, None); + let err = get_result::(&path, Some(get_range.clone()), resp).unwrap_err(); + assert_eq!( + err.to_string(), + "Content-Range header 
not present in partial response" + ); + + let resp = make_response( + 2, + Some(2..3), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-3/2"), + None, + ); + let err = get_result::(&path, Some(get_range.clone()), resp).unwrap_err(); + assert_eq!( + err.to_string(), + "Wanted range starting at 2, but object was only 2 bytes long" + ); + + let resp = make_response( + 6, + Some(2..6), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-5/6"), + None, + ); + let res = get_result::(&path, Some(GetRange::Suffix(4)), resp).unwrap(); + assert_eq!(res.meta.size, 6); + assert_eq!(res.range, 2..6); + let bytes = res.bytes().await.unwrap(); + assert_eq!(bytes.len(), 4); + + let resp = make_response( + 6, + Some(2..6), + StatusCode::PARTIAL_CONTENT, + Some("bytes 2-3/6"), + None, + ); + let err = get_result::(&path, Some(GetRange::Suffix(4)), resp).unwrap_err(); + assert_eq!(err.to_string(), "Requested 2..6, got 2..4"); + + let resp = make_response( + 12, + None, + StatusCode::OK, + None, + Some(vec![("x-test-meta-foo", "bar")]), + ); + let res = get_result::(&path, None, resp).unwrap(); + assert_eq!(res.meta.size, 12); + assert_eq!(res.range, 0..12); + assert_eq!( + res.attributes.get(&Attribute::Metadata("foo".into())), + Some(&"bar".into()) + ); + let bytes = res.bytes().await.unwrap(); + assert_eq!(bytes.len(), 12); + } +} diff --git a/src/client/header.rs b/src/client/header.rs new file mode 100644 index 0000000..d7e14b3 --- /dev/null +++ b/src/client/header.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for extracting ObjectMeta from headers used by AWS, GCP and Azure + +use crate::path::Path; +use crate::ObjectMeta; +use chrono::{DateTime, TimeZone, Utc}; +use http::header::{CONTENT_LENGTH, ETAG, LAST_MODIFIED}; +use http::HeaderMap; + +#[derive(Debug, Copy, Clone)] +/// Configuration for header extraction +pub(crate) struct HeaderConfig { + /// Whether to require an ETag header when extracting [`ObjectMeta`] from headers. + /// + /// Defaults to `true` + pub etag_required: bool, + + /// Whether to require a Last-Modified header when extracting [`ObjectMeta`] from headers. 
+ /// + /// Defaults to `true` + pub last_modified_required: bool, + + /// The version header name if any + pub version_header: Option<&'static str>, + + /// The user defined metadata prefix if any + pub user_defined_metadata_prefix: Option<&'static str>, +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum Error { + #[error("ETag Header missing from response")] + MissingEtag, + + #[error("Received header containing non-ASCII data")] + BadHeader { source: reqwest::header::ToStrError }, + + #[error("Last-Modified Header missing from response")] + MissingLastModified, + + #[error("Content-Length Header missing from response")] + MissingContentLength, + + #[error("Invalid last modified '{}': {}", last_modified, source)] + InvalidLastModified { + last_modified: String, + source: chrono::ParseError, + }, + + #[error("Invalid content length '{}': {}", content_length, source)] + InvalidContentLength { + content_length: String, + source: std::num::ParseIntError, + }, +} + +/// Extracts a PutResult from the provided [`HeaderMap`] +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +pub(crate) fn get_put_result( + headers: &HeaderMap, + version: &str, +) -> Result { + let e_tag = Some(get_etag(headers)?); + let version = get_version(headers, version)?; + Ok(crate::PutResult { e_tag, version }) +} + +/// Extracts a optional version from the provided [`HeaderMap`] +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +pub(crate) fn get_version(headers: &HeaderMap, version: &str) -> Result, Error> { + Ok(match headers.get(version) { + Some(x) => Some( + x.to_str() + .map_err(|source| Error::BadHeader { source })? + .to_string(), + ), + None => None, + }) +} + +/// Extracts an etag from the provided [`HeaderMap`] +pub(crate) fn get_etag(headers: &HeaderMap) -> Result { + let e_tag = headers.get(ETAG).ok_or(Error::MissingEtag)?; + Ok(e_tag + .to_str() + .map_err(|source| Error::BadHeader { source })? + .to_string()) +} + +/// Extracts [`ObjectMeta`] from the provided [`HeaderMap`] +pub(crate) fn header_meta( + location: &Path, + headers: &HeaderMap, + cfg: HeaderConfig, +) -> Result { + let last_modified = match headers.get(LAST_MODIFIED) { + Some(last_modified) => { + let last_modified = last_modified + .to_str() + .map_err(|source| Error::BadHeader { source })?; + + DateTime::parse_from_rfc2822(last_modified) + .map_err(|source| Error::InvalidLastModified { + last_modified: last_modified.into(), + source, + })? + .with_timezone(&Utc) + } + None if cfg.last_modified_required => return Err(Error::MissingLastModified), + None => Utc.timestamp_nanos(0), + }; + + let e_tag = match get_etag(headers) { + Ok(e_tag) => Some(e_tag), + Err(Error::MissingEtag) if !cfg.etag_required => None, + Err(e) => return Err(e), + }; + + let content_length = headers + .get(CONTENT_LENGTH) + .ok_or(Error::MissingContentLength)?; + + let content_length = content_length + .to_str() + .map_err(|source| Error::BadHeader { source })?; + + let size = content_length + .parse() + .map_err(|source| Error::InvalidContentLength { + content_length: content_length.into(), + source, + })?; + + let version = match cfg.version_header.and_then(|h| headers.get(h)) { + Some(v) => Some( + v.to_str() + .map_err(|source| Error::BadHeader { source })? 
+ .to_string(), + ), + None => None, + }; + + Ok(ObjectMeta { + location: location.clone(), + last_modified, + version, + size, + e_tag, + }) +} diff --git a/src/client/list.rs b/src/client/list.rs new file mode 100644 index 0000000..fe9bfeb --- /dev/null +++ b/src/client/list.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::pagination::stream_paginated; +use crate::path::Path; +use crate::Result; +use crate::{ListResult, ObjectMeta}; +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::{StreamExt, TryStreamExt}; +use std::collections::BTreeSet; + +/// A client that can perform paginated list requests +#[async_trait] +pub(crate) trait ListClient: Send + Sync + 'static { + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + offset: Option<&str>, + ) -> Result<(ListResult, Option)>; +} + +/// Extension trait for [`ListClient`] that adds common listing functionality +#[async_trait] +pub(crate) trait ListClientExt { + fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + offset: Option<&Path>, + ) -> BoxStream<'static, Result>; + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result>; + + #[allow(unused)] + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result>; + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result; +} + +#[async_trait] +impl ListClientExt for T { + fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + offset: Option<&Path>, + ) -> BoxStream<'static, Result> { + let offset = offset.map(|x| x.to_string()); + let prefix = prefix + .filter(|x| !x.as_ref().is_empty()) + .map(|p| format!("{}{}", p.as_ref(), crate::path::DELIMITER)); + + stream_paginated( + self.clone(), + (prefix, offset), + move |client, (prefix, offset), token| async move { + let (r, next_token) = client + .list_request( + prefix.as_deref(), + delimiter, + token.as_deref(), + offset.as_deref(), + ) + .await?; + Ok((r, (prefix, offset), next_token)) + }, + ) + .boxed() + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.list_paginated(prefix, false, None) + .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) + .try_flatten() + .boxed() + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + self.list_paginated(prefix, false, Some(offset)) + .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) + .try_flatten() + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let mut stream = self.list_paginated(prefix, true, None); + + let mut common_prefixes = BTreeSet::new(); + let 
mut objects = Vec::new(); + + while let Some(result) = stream.next().await { + let response = result?; + common_prefixes.extend(response.common_prefixes.into_iter()); + objects.extend(response.objects.into_iter()); + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + } +} diff --git a/src/client/mock_server.rs b/src/client/mock_server.rs new file mode 100644 index 0000000..8be4a72 --- /dev/null +++ b/src/client/mock_server.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use futures::future::BoxFuture; +use futures::FutureExt; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use parking_lot::Mutex; +use std::collections::VecDeque; +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::oneshot; +use tokio::task::{JoinHandle, JoinSet}; + +pub(crate) type ResponseFn = + Box) -> BoxFuture<'static, Response> + Send>; + +/// A mock server +pub(crate) struct MockServer { + responses: Arc>>, + shutdown: oneshot::Sender<()>, + handle: JoinHandle<()>, + url: String, +} + +impl MockServer { + pub(crate) async fn new() -> Self { + let responses: Arc>> = + Arc::new(Mutex::new(VecDeque::with_capacity(10))); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + + let (shutdown, mut rx) = oneshot::channel::<()>(); + + let url = format!("http://{}", listener.local_addr().unwrap()); + + let r = Arc::clone(&responses); + let handle = tokio::spawn(async move { + let mut set = JoinSet::new(); + + loop { + let (stream, _) = tokio::select! 
{ + conn = listener.accept() => conn.unwrap(), + _ = &mut rx => break, + }; + + let r = Arc::clone(&r); + set.spawn(async move { + let _ = http1::Builder::new() + .serve_connection( + TokioIo::new(stream), + service_fn(move |req| { + let r = Arc::clone(&r); + let next = r.lock().pop_front(); + async move { + Ok::<_, Infallible>(match next { + Some(r) => r(req).await, + None => Response::new("Hello World".to_string()), + }) + } + }), + ) + .await; + }); + } + + set.abort_all(); + }); + + Self { + responses, + shutdown, + handle, + url, + } + } + + /// The url of the mock server + pub(crate) fn url(&self) -> &str { + &self.url + } + + /// Add a response + pub(crate) fn push(&self, response: Response) { + self.push_fn(|_| response) + } + + /// Add a response function + pub(crate) fn push_fn(&self, f: F) + where + F: FnOnce(Request) -> Response + Send + 'static, + { + let f = Box::new(|req| async move { f(req) }.boxed()); + self.responses.lock().push_back(f) + } + + pub(crate) fn push_async_fn(&self, f: F) + where + F: FnOnce(Request) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + self.responses.lock().push_back(Box::new(|r| f(r).boxed())) + } + + /// Shutdown the mock server + pub(crate) async fn shutdown(self) { + let _ = self.shutdown.send(()); + self.handle.await.unwrap() + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..bd0347b --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,1012 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Generic utilities for reqwest-based ObjectStore implementations + + pub(crate) mod backoff; + + #[cfg(not(target_arch = "wasm32"))] + mod dns; + + #[cfg(test)] + pub(crate) mod mock_server; + + pub(crate) mod retry; + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) mod pagination; + + pub(crate) mod get; + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) mod list; + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) mod token; + + pub(crate) mod header; + + #[cfg(any(feature = "aws", feature = "gcp"))] + pub(crate) mod s3; + + mod body; + pub use body::{HttpRequest, HttpRequestBody, HttpResponse, HttpResponseBody}; + + pub(crate) mod builder; + + mod connection; + pub(crate) use connection::http_connector; + #[cfg(not(target_arch = "wasm32"))] + pub use connection::ReqwestConnector; + pub use connection::{HttpClient, HttpConnector, HttpError, HttpErrorKind, HttpService}; + + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) mod parts; + + use async_trait::async_trait; + use reqwest::header::{HeaderMap, HeaderValue}; + use serde::{Deserialize, Serialize}; + use std::collections::HashMap; + use std::str::FromStr; + use std::sync::Arc; + use std::time::Duration; + + #[cfg(not(target_arch = "wasm32"))] + use reqwest::{NoProxy, Proxy}; + + use crate::config::{fmt_duration, ConfigValue}; + use crate::path::Path; + use crate::{GetOptions, Result}; + + fn map_client_error(e: reqwest::Error) -> super::Error { + super::Error::Generic { + store: "HTTP client", + source: Box::new(e), + } + } + + static DEFAULT_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); + + /// Configuration keys for [`ClientOptions`] + #[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)] + #[non_exhaustive] + pub enum ClientConfigKey { + /// Allow non-TLS, i.e. non-HTTPS connections + AllowHttp, + /// Skip certificate validation on https connections. + /// + /// # Warning + /// + /// You should think very carefully before using this option. If + /// invalid certificates are trusted, *any* certificate for *any* site + /// will be trusted for use. This includes expired certificates. This + /// introduces significant vulnerabilities, and should only be used + /// as a last resort or for testing + AllowInvalidCertificates, + /// Timeout for only the connect phase of a Client + ConnectTimeout, + /// Default CONTENT_TYPE for uploads + DefaultContentType, + /// Only use http1 connections + Http1Only, + /// Interval at which HTTP2 Ping frames are sent to keep a connection alive. + Http2KeepAliveInterval, + /// Timeout for receiving an acknowledgement of the keep-alive ping. + Http2KeepAliveTimeout, + /// Enable HTTP2 keep alive pings for idle connections + Http2KeepAliveWhileIdle, + /// Sets the maximum frame size to use for HTTP2. + Http2MaxFrameSize, + /// Only use http2 connections + Http2Only, + /// The pool max idle timeout + /// + /// This is the length of time an idle connection will be kept alive + PoolIdleTimeout, + /// Maximum number of idle connections per host + PoolMaxIdlePerHost, + /// HTTP proxy to use for requests + ProxyUrl, + /// PEM-formatted CA certificate for proxy connections + ProxyCaCertificate, + /// List of hosts that bypass proxy + ProxyExcludes, + /// Randomize the order of addresses that the DNS resolution yields. + /// + /// This will spread the connections across more servers.
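+ ///
+ /// Defaults to `true` (see the `Default` impl of [`ClientOptions`] below)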
+ RandomizeAddresses, + /// Request timeout + /// + /// The timeout is applied from when the request starts connecting until the + /// response body has finished + Timeout, + /// User-Agent header to be used by this client + UserAgent, +} + +impl AsRef for ClientConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::AllowHttp => "allow_http", + Self::AllowInvalidCertificates => "allow_invalid_certificates", + Self::ConnectTimeout => "connect_timeout", + Self::DefaultContentType => "default_content_type", + Self::Http1Only => "http1_only", + Self::Http2Only => "http2_only", + Self::Http2KeepAliveInterval => "http2_keep_alive_interval", + Self::Http2KeepAliveTimeout => "http2_keep_alive_timeout", + Self::Http2KeepAliveWhileIdle => "http2_keep_alive_while_idle", + Self::Http2MaxFrameSize => "http2_max_frame_size", + Self::PoolIdleTimeout => "pool_idle_timeout", + Self::PoolMaxIdlePerHost => "pool_max_idle_per_host", + Self::ProxyUrl => "proxy_url", + Self::ProxyCaCertificate => "proxy_ca_certificate", + Self::ProxyExcludes => "proxy_excludes", + Self::RandomizeAddresses => "randomize_addresses", + Self::Timeout => "timeout", + Self::UserAgent => "user_agent", + } + } +} + +impl FromStr for ClientConfigKey { + type Err = super::Error; + + fn from_str(s: &str) -> Result { + match s { + "allow_http" => Ok(Self::AllowHttp), + "allow_invalid_certificates" => Ok(Self::AllowInvalidCertificates), + "connect_timeout" => Ok(Self::ConnectTimeout), + "default_content_type" => Ok(Self::DefaultContentType), + "http1_only" => Ok(Self::Http1Only), + "http2_only" => Ok(Self::Http2Only), + "http2_keep_alive_interval" => Ok(Self::Http2KeepAliveInterval), + "http2_keep_alive_timeout" => Ok(Self::Http2KeepAliveTimeout), + "http2_keep_alive_while_idle" => Ok(Self::Http2KeepAliveWhileIdle), + "http2_max_frame_size" => Ok(Self::Http2MaxFrameSize), + "pool_idle_timeout" => Ok(Self::PoolIdleTimeout), + "pool_max_idle_per_host" => Ok(Self::PoolMaxIdlePerHost), + "proxy_url" => Ok(Self::ProxyUrl), + "proxy_ca_certificate" => Ok(Self::ProxyCaCertificate), + "proxy_excludes" => Ok(Self::ProxyExcludes), + "randomize_addresses" => Ok(Self::RandomizeAddresses), + "timeout" => Ok(Self::Timeout), + "user_agent" => Ok(Self::UserAgent), + _ => Err(super::Error::UnknownConfigurationKey { + store: "HTTP", + key: s.into(), + }), + } + } +} + +/// Represents a CA certificate provided by the user. +/// +/// This is used to configure the client to trust a specific certificate. See +/// [Self::from_pem] for an example +#[derive(Debug, Clone)] +#[cfg(not(target_arch = "wasm32"))] +pub struct Certificate(reqwest::tls::Certificate); + +#[cfg(not(target_arch = "wasm32"))] +impl Certificate { + /// Create a `Certificate` from a PEM encoded certificate. + /// + /// # Example from a PEM file + /// + /// ```no_run + /// # use object_store::Certificate; + /// # use std::fs::File; + /// # use std::io::Read; + /// let mut buf = Vec::new(); + /// File::open("my_cert.pem").unwrap() + /// .read_to_end(&mut buf).unwrap(); + /// let cert = Certificate::from_pem(&buf).unwrap(); + /// + /// ``` + pub fn from_pem(pem: &[u8]) -> Result { + Ok(Self( + reqwest::tls::Certificate::from_pem(pem).map_err(map_client_error)?, + )) + } + + /// Create a collection of `Certificate` from a PEM encoded certificate + /// bundle. + /// + /// Files that contain such collections have extensions such as `.crt`, + /// `.cer` and `.pem` files. 
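+ ///
+ /// # Example from a PEM bundle file
+ ///
+ /// A minimal sketch in the same style as [`Self::from_pem`]; the bundle path is
+ /// purely illustrative.
+ ///
+ /// ```no_run
+ /// # use object_store::Certificate;
+ /// # use std::fs::File;
+ /// # use std::io::Read;
+ /// let mut buf = Vec::new();
+ /// File::open("my_bundle.crt").unwrap()
+ ///     .read_to_end(&mut buf).unwrap();
+ /// // Each certificate in the bundle becomes its own `Certificate`
+ /// let certs = Certificate::from_pem_bundle(&buf).unwrap();
+ /// ```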
+ pub fn from_pem_bundle(pem_bundle: &[u8]) -> Result> { + Ok(reqwest::tls::Certificate::from_pem_bundle(pem_bundle) + .map_err(map_client_error)? + .into_iter() + .map(Self) + .collect()) + } + + /// Create a `Certificate` from a binary DER encoded certificate. + pub fn from_der(der: &[u8]) -> Result { + Ok(Self( + reqwest::tls::Certificate::from_der(der).map_err(map_client_error)?, + )) + } +} + +/// HTTP client configuration for remote object stores +#[derive(Debug, Clone)] +pub struct ClientOptions { + user_agent: Option>, + #[cfg(not(target_arch = "wasm32"))] + root_certificates: Vec, + content_type_map: HashMap, + default_content_type: Option, + default_headers: Option, + proxy_url: Option, + proxy_ca_certificate: Option, + proxy_excludes: Option, + allow_http: ConfigValue, + allow_insecure: ConfigValue, + timeout: Option>, + connect_timeout: Option>, + pool_idle_timeout: Option>, + pool_max_idle_per_host: Option>, + http2_keep_alive_interval: Option>, + http2_keep_alive_timeout: Option>, + http2_keep_alive_while_idle: ConfigValue, + http2_max_frame_size: Option>, + http1_only: ConfigValue, + http2_only: ConfigValue, + randomize_addresses: ConfigValue, +} + +impl Default for ClientOptions { + fn default() -> Self { + // Defaults based on + // + // + // Which recommend a connection timeout of 3.1s and a request timeout of 2s + // + // As object store requests may involve the transfer of non-trivial volumes of data + // we opt for a slightly higher default timeout of 30 seconds + Self { + user_agent: None, + #[cfg(not(target_arch = "wasm32"))] + root_certificates: Default::default(), + content_type_map: Default::default(), + default_content_type: None, + default_headers: None, + proxy_url: None, + proxy_ca_certificate: None, + proxy_excludes: None, + allow_http: Default::default(), + allow_insecure: Default::default(), + timeout: Some(Duration::from_secs(30).into()), + connect_timeout: Some(Duration::from_secs(5).into()), + pool_idle_timeout: None, + pool_max_idle_per_host: None, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: Default::default(), + http2_max_frame_size: None, + // HTTP2 is known to be significantly slower than HTTP1, so we default + // to HTTP1 for now. 
+ // https://github.com/apache/arrow-rs/issues/5194 + http1_only: true.into(), + http2_only: Default::default(), + randomize_addresses: true.into(), + } + } +} + +impl ClientOptions { + /// Create a new [`ClientOptions`] with default values + pub fn new() -> Self { + Default::default() + } + + /// Set an option by key + pub fn with_config(mut self, key: ClientConfigKey, value: impl Into) -> Self { + match key { + ClientConfigKey::AllowHttp => self.allow_http.parse(value), + ClientConfigKey::AllowInvalidCertificates => self.allow_insecure.parse(value), + ClientConfigKey::ConnectTimeout => { + self.connect_timeout = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::DefaultContentType => self.default_content_type = Some(value.into()), + ClientConfigKey::Http1Only => self.http1_only.parse(value), + ClientConfigKey::Http2Only => self.http2_only.parse(value), + ClientConfigKey::Http2KeepAliveInterval => { + self.http2_keep_alive_interval = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::Http2KeepAliveTimeout => { + self.http2_keep_alive_timeout = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::Http2KeepAliveWhileIdle => { + self.http2_keep_alive_while_idle.parse(value) + } + ClientConfigKey::Http2MaxFrameSize => { + self.http2_max_frame_size = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::PoolIdleTimeout => { + self.pool_idle_timeout = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::PoolMaxIdlePerHost => { + self.pool_max_idle_per_host = Some(ConfigValue::Deferred(value.into())) + } + ClientConfigKey::ProxyUrl => self.proxy_url = Some(value.into()), + ClientConfigKey::ProxyCaCertificate => self.proxy_ca_certificate = Some(value.into()), + ClientConfigKey::ProxyExcludes => self.proxy_excludes = Some(value.into()), + ClientConfigKey::RandomizeAddresses => { + self.randomize_addresses.parse(value); + } + ClientConfigKey::Timeout => self.timeout = Some(ConfigValue::Deferred(value.into())), + ClientConfigKey::UserAgent => { + self.user_agent = Some(ConfigValue::Deferred(value.into())) + } + } + self + } + + /// Get an option by key + pub fn get_config_value(&self, key: &ClientConfigKey) -> Option { + match key { + ClientConfigKey::AllowHttp => Some(self.allow_http.to_string()), + ClientConfigKey::AllowInvalidCertificates => Some(self.allow_insecure.to_string()), + ClientConfigKey::ConnectTimeout => self.connect_timeout.as_ref().map(fmt_duration), + ClientConfigKey::DefaultContentType => self.default_content_type.clone(), + ClientConfigKey::Http1Only => Some(self.http1_only.to_string()), + ClientConfigKey::Http2KeepAliveInterval => { + self.http2_keep_alive_interval.as_ref().map(fmt_duration) + } + ClientConfigKey::Http2KeepAliveTimeout => { + self.http2_keep_alive_timeout.as_ref().map(fmt_duration) + } + ClientConfigKey::Http2KeepAliveWhileIdle => { + Some(self.http2_keep_alive_while_idle.to_string()) + } + ClientConfigKey::Http2MaxFrameSize => { + self.http2_max_frame_size.as_ref().map(|v| v.to_string()) + } + ClientConfigKey::Http2Only => Some(self.http2_only.to_string()), + ClientConfigKey::PoolIdleTimeout => self.pool_idle_timeout.as_ref().map(fmt_duration), + ClientConfigKey::PoolMaxIdlePerHost => { + self.pool_max_idle_per_host.as_ref().map(|v| v.to_string()) + } + ClientConfigKey::ProxyUrl => self.proxy_url.clone(), + ClientConfigKey::ProxyCaCertificate => self.proxy_ca_certificate.clone(), + ClientConfigKey::ProxyExcludes => self.proxy_excludes.clone(), + ClientConfigKey::RandomizeAddresses => 
Some(self.randomize_addresses.to_string()), + ClientConfigKey::Timeout => self.timeout.as_ref().map(fmt_duration), + ClientConfigKey::UserAgent => self + .user_agent + .as_ref() + .and_then(|v| v.get().ok()) + .and_then(|v| v.to_str().ok().map(|s| s.to_string())), + } + } + + /// Sets the User-Agent header to be used by this client + /// + /// Default is based on the version of this crate + pub fn with_user_agent(mut self, agent: HeaderValue) -> Self { + self.user_agent = Some(agent.into()); + self + } + + /// Add a custom root certificate. + /// + /// This can be used to connect to a server that has a self-signed + /// certificate for example. + #[cfg(not(target_arch = "wasm32"))] + pub fn with_root_certificate(mut self, certificate: Certificate) -> Self { + self.root_certificates.push(certificate); + self + } + + /// Set the default CONTENT_TYPE for uploads + pub fn with_default_content_type(mut self, mime: impl Into) -> Self { + self.default_content_type = Some(mime.into()); + self + } + + /// Set the CONTENT_TYPE for a given file extension + pub fn with_content_type_for_suffix( + mut self, + extension: impl Into, + mime: impl Into, + ) -> Self { + self.content_type_map.insert(extension.into(), mime.into()); + self + } + + /// Sets the default headers for every request + pub fn with_default_headers(mut self, headers: HeaderMap) -> Self { + self.default_headers = Some(headers); + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.allow_http = allow_http.into(); + self + } + /// Allows connections to invalid SSL certificates + /// * false (default): Only valid HTTPS certificates are allowed + /// * true: All HTTPS certificates are allowed + /// + /// # Warning + /// + /// You should think very carefully before using this method. If + /// invalid certificates are trusted, *any* certificate for *any* site + /// will be trusted for use. This includes expired certificates. This + /// introduces significant vulnerabilities, and should only be used + /// as a last resort or for testing + pub fn with_allow_invalid_certificates(mut self, allow_insecure: bool) -> Self { + self.allow_insecure = allow_insecure.into(); + self + } + + /// Only use http1 connections + /// + /// This is on by default, since http2 is known to be significantly slower than http1. + pub fn with_http1_only(mut self) -> Self { + self.http2_only = false.into(); + self.http1_only = true.into(); + self + } + + /// Only use http2 connections + pub fn with_http2_only(mut self) -> Self { + self.http1_only = false.into(); + self.http2_only = true.into(); + self + } + + /// Use http2 if supported, otherwise use http1. 
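+ ///
+ /// A short sketch of opting back in to protocol negotiation (assuming, as in the
+ /// other doc examples, that [`ClientOptions`] is exported from the crate root):
+ ///
+ /// ```
+ /// # use object_store::ClientOptions;
+ /// // HTTP/1 only is the default; this re-enables HTTP/2 where the server supports it
+ /// let options = ClientOptions::new().with_allow_http2();
+ /// ```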
+ pub fn with_allow_http2(mut self) -> Self { + self.http1_only = false.into(); + self.http2_only = false.into(); + self + } + + /// Set a proxy URL to use for requests + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.proxy_url = Some(proxy_url.into()); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate(mut self, proxy_ca_certificate: impl Into) -> Self { + self.proxy_ca_certificate = Some(proxy_ca_certificate.into()); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.proxy_excludes = Some(proxy_excludes.into()); + self + } + + /// Set a request timeout + /// + /// The timeout is applied from when the request starts connecting until the + /// response body has finished + /// + /// Default is 30 seconds + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(ConfigValue::Parsed(timeout)); + self + } + + /// Disables the request timeout + /// + /// See [`Self::with_timeout`] + pub fn with_timeout_disabled(mut self) -> Self { + self.timeout = None; + self + } + + /// Set a timeout for only the connect phase of a Client + /// + /// Default is 5 seconds + pub fn with_connect_timeout(mut self, timeout: Duration) -> Self { + self.connect_timeout = Some(ConfigValue::Parsed(timeout)); + self + } + + /// Disables the connection timeout + /// + /// See [`Self::with_connect_timeout`] + pub fn with_connect_timeout_disabled(mut self) -> Self { + self.connect_timeout = None; + self + } + + /// Set the pool max idle timeout + /// + /// This is the length of time an idle connection will be kept alive + /// + /// Default is 90 seconds enforced by reqwest + pub fn with_pool_idle_timeout(mut self, timeout: Duration) -> Self { + self.pool_idle_timeout = Some(ConfigValue::Parsed(timeout)); + self + } + + /// Set the maximum number of idle connections per host + /// + /// Default is no limit enforced by reqwest + pub fn with_pool_max_idle_per_host(mut self, max: usize) -> Self { + self.pool_max_idle_per_host = Some(max.into()); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a connection alive. + /// + /// Default is disabled enforced by reqwest + pub fn with_http2_keep_alive_interval(mut self, interval: Duration) -> Self { + self.http2_keep_alive_interval = Some(ConfigValue::Parsed(interval)); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will be closed. + /// Does nothing if http2_keep_alive_interval is disabled. + /// + /// Default is disabled enforced by reqwest + pub fn with_http2_keep_alive_timeout(mut self, interval: Duration) -> Self { + self.http2_keep_alive_timeout = Some(ConfigValue::Parsed(interval)); + self + } + + /// Enable HTTP2 keep alive pings for idle connections + /// + /// If disabled, keep-alive pings are only sent while there are open request/response + /// streams. If enabled, pings are also sent when no streams are active + /// + /// Default is disabled enforced by reqwest + pub fn with_http2_keep_alive_while_idle(mut self) -> Self { + self.http2_keep_alive_while_idle = true.into(); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Default is currently 16,384 but may change internally to optimize for common uses. 
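+ ///
+ /// Note that this setting only takes effect when HTTP/2 is actually used, e.g.
+ /// after calling [`Self::with_http2_only`] or [`Self::with_allow_http2`], since
+ /// this crate defaults to HTTP/1 only connections.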
+ pub fn with_http2_max_frame_size(mut self, sz: u32) -> Self { + self.http2_max_frame_size = Some(ConfigValue::Parsed(sz)); + self + } + + /// Get the mime type for the file in `path` to be uploaded + /// + /// Gets the file extension from `path`, and returns the + /// mime type if it was defined initially through + /// `ClientOptions::with_content_type_for_suffix` + /// + /// Otherwise, returns the default mime type if it was defined + /// earlier through `ClientOptions::with_default_content_type` + pub fn get_content_type(&self, path: &Path) -> Option<&str> { + match path.extension() { + Some(extension) => match self.content_type_map.get(extension) { + Some(ct) => Some(ct.as_str()), + None => self.default_content_type.as_deref(), + }, + None => self.default_content_type.as_deref(), + } + } + + /// Returns a copy of this [`ClientOptions`] with overrides necessary for metadata endpoint access + /// + /// In particular: + /// * Allows HTTP as metadata endpoints do not use TLS + /// * Configures a low connection timeout to provide quick feedback if not present + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn metadata_options(&self) -> Self { + self.clone() + .with_allow_http(true) + .with_connect_timeout(Duration::from_secs(1)) + } + + #[cfg(not(target_arch = "wasm32"))] + pub(crate) fn client(&self) -> Result { + let mut builder = reqwest::ClientBuilder::new(); + + match &self.user_agent { + Some(user_agent) => builder = builder.user_agent(user_agent.get()?), + None => builder = builder.user_agent(DEFAULT_USER_AGENT), + } + + if let Some(headers) = &self.default_headers { + builder = builder.default_headers(headers.clone()) + } + + if let Some(proxy) = &self.proxy_url { + let mut proxy = Proxy::all(proxy).map_err(map_client_error)?; + + if let Some(certificate) = &self.proxy_ca_certificate { + let certificate = reqwest::tls::Certificate::from_pem(certificate.as_bytes()) + .map_err(map_client_error)?; + + builder = builder.add_root_certificate(certificate); + } + + if let Some(proxy_excludes) = &self.proxy_excludes { + let no_proxy = NoProxy::from_string(proxy_excludes); + + proxy = proxy.no_proxy(no_proxy); + } + + builder = builder.proxy(proxy); + } + + for certificate in &self.root_certificates { + builder = builder.add_root_certificate(certificate.0.clone()); + } + + if let Some(timeout) = &self.timeout { + builder = builder.timeout(timeout.get()?) + } + + if let Some(timeout) = &self.connect_timeout { + builder = builder.connect_timeout(timeout.get()?) + } + + if let Some(timeout) = &self.pool_idle_timeout { + builder = builder.pool_idle_timeout(timeout.get()?) + } + + if let Some(max) = &self.pool_max_idle_per_host { + builder = builder.pool_max_idle_per_host(max.get()?) + } + + if let Some(interval) = &self.http2_keep_alive_interval { + builder = builder.http2_keep_alive_interval(interval.get()?) + } + + if let Some(interval) = &self.http2_keep_alive_timeout { + builder = builder.http2_keep_alive_timeout(interval.get()?) + } + + if self.http2_keep_alive_while_idle.get()? { + builder = builder.http2_keep_alive_while_idle(true) + } + + if let Some(sz) = &self.http2_max_frame_size { + builder = builder.http2_max_frame_size(Some(sz.get()?)) + } + + if self.http1_only.get()? { + builder = builder.http1_only() + } + + if self.http2_only.get()? { + builder = builder.http2_prior_knowledge() + } + + if self.allow_insecure.get()? 
{ + builder = builder.danger_accept_invalid_certs(true) + } + + // Explicitly disable compression, since it may be automatically enabled + // when certain reqwest features are enabled. Compression interferes + // with the `Content-Length` header, which is used to determine the + // size of objects. + builder = builder.no_gzip().no_brotli().no_zstd().no_deflate(); + + if self.randomize_addresses.get()? { + builder = builder.dns_resolver(Arc::new(dns::ShuffleResolver)); + } + + builder + .https_only(!self.allow_http.get()?) + .build() + .map_err(map_client_error) + } +} + +pub(crate) trait GetOptionsExt { + fn with_get_options(self, options: GetOptions) -> Self; +} + +impl GetOptionsExt for HttpRequestBuilder { + fn with_get_options(mut self, options: GetOptions) -> Self { + use hyper::header::*; + + let GetOptions { + if_match, + if_none_match, + if_modified_since, + if_unmodified_since, + range, + version: _, + head: _, + extensions, + } = options; + + if let Some(range) = range { + self = self.header(RANGE, range.to_string()); + } + + if let Some(tag) = if_match { + self = self.header(IF_MATCH, tag); + } + + if let Some(tag) = if_none_match { + self = self.header(IF_NONE_MATCH, tag); + } + + const DATE_FORMAT: &str = "%a, %d %b %Y %H:%M:%S GMT"; + if let Some(date) = if_unmodified_since { + self = self.header(IF_UNMODIFIED_SINCE, date.format(DATE_FORMAT).to_string()); + } + + if let Some(date) = if_modified_since { + self = self.header(IF_MODIFIED_SINCE, date.format(DATE_FORMAT).to_string()); + } + + self = self.extensions(extensions); + + self + } +} + +/// Provides credentials for use when signing requests +#[async_trait] +pub trait CredentialProvider: std::fmt::Debug + Send + Sync { + /// The type of credential returned by this provider + type Credential; + + /// Return a credential + async fn get_credential(&self) -> Result>; +} + +/// A static set of credentials +#[derive(Debug)] +pub struct StaticCredentialProvider { + credential: Arc, +} + +impl StaticCredentialProvider { + /// A [`CredentialProvider`] for a static credential of type `T` + pub fn new(credential: T) -> Self { + Self { + credential: Arc::new(credential), + } + } +} + +#[async_trait] +impl CredentialProvider for StaticCredentialProvider +where + T: std::fmt::Debug + Send + Sync, +{ + type Credential = T; + + async fn get_credential(&self) -> Result> { + Ok(Arc::clone(&self.credential)) + } +} + +#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))] +mod cloud { + use super::*; + use crate::client::token::{TemporaryToken, TokenCache}; + use crate::RetryConfig; + + /// A [`CredentialProvider`] that uses [`HttpClient`] to fetch temporary tokens + #[derive(Debug)] + pub(crate) struct TokenCredentialProvider { + inner: T, + client: HttpClient, + retry: RetryConfig, + cache: TokenCache>, + } + + impl TokenCredentialProvider { + pub(crate) fn new(inner: T, client: HttpClient, retry: RetryConfig) -> Self { + Self { + inner, + client, + retry, + cache: Default::default(), + } + } + + /// Override the minimum remaining TTL for a cached token to be used + #[cfg(any(feature = "aws", feature = "gcp"))] + pub(crate) fn with_min_ttl(mut self, min_ttl: Duration) -> Self { + self.cache = self.cache.with_min_ttl(min_ttl); + self + } + } + + #[async_trait] + impl CredentialProvider for TokenCredentialProvider { + type Credential = T::Credential; + + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| self.inner.fetch_token(&self.client, &self.retry)) + .await + } + } + + #[async_trait] + 
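+ /// Fetches short-lived credentials on demand; usually wrapped in a
+ /// [`TokenCredentialProvider`], which caches the returned token and only calls
+ /// [`TokenProvider::fetch_token`] again once the cached token nears expiry.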
pub(crate) trait TokenProvider: std::fmt::Debug + Send + Sync { + type Credential: std::fmt::Debug + Send + Sync; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> Result>>; + } +} + +use crate::client::builder::HttpRequestBuilder; +#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))] +pub(crate) use cloud::*; + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn client_test_config_from_map() { + let allow_http = "true".to_string(); + let allow_invalid_certificates = "false".to_string(); + let connect_timeout = "90 seconds".to_string(); + let default_content_type = "object_store:fake_default_content_type".to_string(); + let http1_only = "true".to_string(); + let http2_only = "false".to_string(); + let http2_keep_alive_interval = "90 seconds".to_string(); + let http2_keep_alive_timeout = "91 seconds".to_string(); + let http2_keep_alive_while_idle = "92 seconds".to_string(); + let http2_max_frame_size = "1337".to_string(); + let pool_idle_timeout = "93 seconds".to_string(); + let pool_max_idle_per_host = "94".to_string(); + let proxy_url = "https://fake_proxy_url".to_string(); + let timeout = "95 seconds".to_string(); + let user_agent = "object_store:fake_user_agent".to_string(); + + let options = HashMap::from([ + ("allow_http", allow_http.clone()), + ( + "allow_invalid_certificates", + allow_invalid_certificates.clone(), + ), + ("connect_timeout", connect_timeout.clone()), + ("default_content_type", default_content_type.clone()), + ("http1_only", http1_only.clone()), + ("http2_only", http2_only.clone()), + ( + "http2_keep_alive_interval", + http2_keep_alive_interval.clone(), + ), + ("http2_keep_alive_timeout", http2_keep_alive_timeout.clone()), + ( + "http2_keep_alive_while_idle", + http2_keep_alive_while_idle.clone(), + ), + ("http2_max_frame_size", http2_max_frame_size.clone()), + ("pool_idle_timeout", pool_idle_timeout.clone()), + ("pool_max_idle_per_host", pool_max_idle_per_host.clone()), + ("proxy_url", proxy_url.clone()), + ("timeout", timeout.clone()), + ("user_agent", user_agent.clone()), + ]); + + let builder = options + .into_iter() + .fold(ClientOptions::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }); + + assert_eq!( + builder + .get_config_value(&ClientConfigKey::AllowHttp) + .unwrap(), + allow_http + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::AllowInvalidCertificates) + .unwrap(), + allow_invalid_certificates + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::ConnectTimeout) + .unwrap(), + connect_timeout + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::DefaultContentType) + .unwrap(), + default_content_type + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http1Only) + .unwrap(), + http1_only + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http2Only) + .unwrap(), + http2_only + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http2KeepAliveInterval) + .unwrap(), + http2_keep_alive_interval + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http2KeepAliveTimeout) + .unwrap(), + http2_keep_alive_timeout + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http2KeepAliveWhileIdle) + .unwrap(), + http2_keep_alive_while_idle + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::Http2MaxFrameSize) + .unwrap(), + http2_max_frame_size + ); + + assert_eq!( + builder + 
.get_config_value(&ClientConfigKey::PoolIdleTimeout) + .unwrap(), + pool_idle_timeout + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::PoolMaxIdlePerHost) + .unwrap(), + pool_max_idle_per_host + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::ProxyUrl) + .unwrap(), + proxy_url + ); + assert_eq!( + builder.get_config_value(&ClientConfigKey::Timeout).unwrap(), + timeout + ); + assert_eq!( + builder + .get_config_value(&ClientConfigKey::UserAgent) + .unwrap(), + user_agent + ); + } +} diff --git a/src/client/pagination.rs b/src/client/pagination.rs new file mode 100644 index 0000000..d789c74 --- /dev/null +++ b/src/client/pagination.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use futures::Stream; +use std::future::Future; + +/// Takes a paginated operation `op` that when called with: +/// +/// - A state `S` +/// - An optional next token `Option` +/// +/// Returns +/// +/// - A response value `T` +/// - The next state `S` +/// - The next continuation token `Option` +/// +/// And converts it into a `Stream>` which will first call `op(state, None)`, and yield +/// the returned response `T`. If the returned continuation token was `None` the stream will then +/// finish, otherwise it will continue to call `op(state, token)` with the values returned by the +/// previous call to `op`, until a continuation token of `None` is returned +/// +pub(crate) fn stream_paginated( + client: C, + state: S, + op: F, +) -> impl Stream> +where + C: Clone, + F: Fn(C, S, Option) -> Fut + Copy, + Fut: Future)>>, +{ + enum PaginationState { + Start(T), + HasMore(T, String), + Done, + } + + futures::stream::unfold(PaginationState::Start(state), move |state| { + let client = client.clone(); + async move { + let (s, page_token) = match state { + PaginationState::Start(s) => (s, None), + PaginationState::HasMore(s, page_token) if !page_token.is_empty() => { + (s, Some(page_token)) + } + _ => { + return None; + } + }; + + let (resp, s, continuation) = match op(client, s, page_token).await { + Ok(resp) => resp, + Err(e) => return Some((Err(e), PaginationState::Done)), + }; + + let next_state = match continuation { + Some(token) => PaginationState::HasMore(s, token), + None => PaginationState::Done, + }; + + Some((Ok(resp), next_state)) + } + }) +} diff --git a/src/client/parts.rs b/src/client/parts.rs new file mode 100644 index 0000000..9fc301e --- /dev/null +++ b/src/client/parts.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::multipart::PartId; +use parking_lot::Mutex; + +/// An interior mutable collection of upload parts and their corresponding part index +#[derive(Debug, Default)] +pub(crate) struct Parts(Mutex>); + +impl Parts { + /// Record the [`PartId`] for a given index + /// + /// Note: calling this method multiple times with the same `part_idx` + /// will result in multiple [`PartId`] in the final output + pub(crate) fn put(&self, part_idx: usize, id: PartId) { + self.0.lock().push((part_idx, id)) + } + + /// Produce the final list of [`PartId`] ordered by `part_idx` + /// + /// `expected` is the number of parts expected in the final result + pub(crate) fn finish(&self, expected: usize) -> crate::Result> { + let mut parts = self.0.lock(); + if parts.len() != expected { + return Err(crate::Error::Generic { + store: "Parts", + source: "Missing part".to_string().into(), + }); + } + parts.sort_unstable_by_key(|(idx, _)| *idx); + Ok(parts.drain(..).map(|(_, v)| v).collect()) + } +} diff --git a/src/client/retry.rs b/src/client/retry.rs new file mode 100644 index 0000000..96244aa --- /dev/null +++ b/src/client/retry.rs @@ -0,0 +1,754 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A shared HTTP client implementation incorporating retries + +use crate::client::backoff::{Backoff, BackoffConfig}; +use crate::client::builder::HttpRequestBuilder; +use crate::client::connection::HttpErrorKind; +use crate::client::{HttpClient, HttpError, HttpRequest, HttpResponse}; +use crate::PutPayload; +use futures::future::BoxFuture; +use http::{Method, Uri}; +use reqwest::header::LOCATION; +use reqwest::StatusCode; +use std::time::{Duration, Instant}; +use tracing::info; + +/// Retry request error +#[derive(Debug, thiserror::Error)] +pub struct RetryError { + method: Method, + uri: Option, + retries: usize, + max_retries: usize, + elapsed: Duration, + retry_timeout: Duration, + inner: RequestError, +} + +impl std::fmt::Display for RetryError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error performing {} ", self.method)?; + match &self.uri { + Some(uri) => write!(f, "{uri} ")?, + None => write!(f, "REDACTED ")?, + } + write!(f, "in {:?}", self.elapsed)?; + if self.retries != 0 { + write!( + f, + ", after {} retries, max_retries: {}, retry_timeout: {:?} ", + self.retries, self.max_retries, self.retry_timeout + )?; + } + write!(f, " - {}", self.inner) + } +} + +/// Context of the retry loop +struct RetryContext { + method: Method, + uri: Option, + retries: usize, + max_retries: usize, + start: Instant, + retry_timeout: Duration, +} + +impl RetryContext { + fn err(self, error: RequestError) -> RetryError { + RetryError { + uri: self.uri, + method: self.method, + retries: self.retries, + max_retries: self.max_retries, + elapsed: self.start.elapsed(), + retry_timeout: self.retry_timeout, + inner: error, + } + } + + fn exhausted(&self) -> bool { + self.retries == self.max_retries || self.start.elapsed() > self.retry_timeout + } +} + +/// The reason a request failed +#[derive(Debug, thiserror::Error)] +pub enum RequestError { + #[error("Received redirect without LOCATION, this normally indicates an incorrectly configured region" + )] + BareRedirect, + + #[error("Server returned non-2xx status code: {status}: {}", body.as_deref().unwrap_or(""))] + Status { + status: StatusCode, + body: Option, + }, + + #[error("Server returned error response: {body}")] + Response { status: StatusCode, body: String }, + + #[error(transparent)] + Http(#[from] HttpError), +} + +impl RetryError { + /// Returns the underlying [`RequestError`] + pub fn inner(&self) -> &RequestError { + &self.inner + } + + /// Returns the status code associated with this error if any + pub fn status(&self) -> Option { + match &self.inner { + RequestError::Status { status, .. } | RequestError::Response { status, .. } => { + Some(*status) + } + RequestError::BareRedirect | RequestError::Http(_) => None, + } + } + + /// Returns the error body if any + pub fn body(&self) -> Option<&str> { + match &self.inner { + RequestError::Status { body, .. } => body.as_deref(), + RequestError::Response { body, .. 
} => Some(body), + RequestError::BareRedirect | RequestError::Http(_) => None, + } + } + + pub fn error(self, store: &'static str, path: String) -> crate::Error { + match self.status() { + Some(StatusCode::NOT_FOUND) => crate::Error::NotFound { + path, + source: Box::new(self), + }, + Some(StatusCode::NOT_MODIFIED) => crate::Error::NotModified { + path, + source: Box::new(self), + }, + Some(StatusCode::PRECONDITION_FAILED) => crate::Error::Precondition { + path, + source: Box::new(self), + }, + Some(StatusCode::CONFLICT) => crate::Error::AlreadyExists { + path, + source: Box::new(self), + }, + Some(StatusCode::FORBIDDEN) => crate::Error::PermissionDenied { + path, + source: Box::new(self), + }, + Some(StatusCode::UNAUTHORIZED) => crate::Error::Unauthenticated { + path, + source: Box::new(self), + }, + _ => crate::Error::Generic { + store, + source: Box::new(self), + }, + } + } +} + +impl From for std::io::Error { + fn from(err: RetryError) -> Self { + use std::io::ErrorKind; + let kind = match err.status() { + Some(StatusCode::NOT_FOUND) => ErrorKind::NotFound, + Some(StatusCode::BAD_REQUEST) => ErrorKind::InvalidInput, + Some(StatusCode::UNAUTHORIZED) | Some(StatusCode::FORBIDDEN) => { + ErrorKind::PermissionDenied + } + _ => match &err.inner { + RequestError::Http(h) => match h.kind() { + HttpErrorKind::Timeout => ErrorKind::TimedOut, + HttpErrorKind::Connect => ErrorKind::NotConnected, + _ => ErrorKind::Other, + }, + _ => ErrorKind::Other, + }, + }; + Self::new(kind, err) + } +} + +pub(crate) type Result = std::result::Result; + +/// The configuration for how to respond to request errors +/// +/// The following categories of error will be retried: +/// +/// * 5xx server errors +/// * Connection errors +/// * Dropped connections +/// * Timeouts for [safe] / read-only requests +/// +/// Requests will be retried up to some limit, using exponential +/// backoff with jitter. See [`BackoffConfig`] for more information +/// +/// [safe]: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// The backoff configuration + pub backoff: BackoffConfig, + + /// The maximum number of times to retry a request + /// + /// Set to 0 to disable retries + pub max_retries: usize, + + /// The maximum length of time from the initial request + /// after which no further retries will be attempted + /// + /// This not only bounds the length of time before a server + /// error will be surfaced to the application, but also bounds + /// the length of time a request's credentials must remain valid. 
+ /// + /// As requests are retried without renewing credentials or + /// regenerating request payloads, this number should be kept + /// below 5 minutes to avoid errors due to expired credentials + /// and/or request payloads + pub retry_timeout: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + backoff: Default::default(), + max_retries: 10, + retry_timeout: Duration::from_secs(3 * 60), + } + } +} + +fn body_contains_error(response_body: &str) -> bool { + response_body.contains("InternalError") || response_body.contains("SlowDown") +} + +pub(crate) struct RetryableRequest { + client: HttpClient, + request: HttpRequest, + + max_retries: usize, + retry_timeout: Duration, + backoff: Backoff, + + sensitive: bool, + idempotent: Option, + retry_on_conflict: bool, + payload: Option, + + retry_error_body: bool, +} + +impl RetryableRequest { + /// Set whether this request is idempotent + /// + /// An idempotent request will be retried on timeout even if the request + /// method is not [safe](https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1) + pub(crate) fn idempotent(self, idempotent: bool) -> Self { + Self { + idempotent: Some(idempotent), + ..self + } + } + + /// Set whether this request should be retried on a 409 Conflict response. + #[cfg(feature = "aws")] + pub(crate) fn retry_on_conflict(self, retry_on_conflict: bool) -> Self { + Self { + retry_on_conflict, + ..self + } + } + + /// Set whether this request contains sensitive data + /// + /// This will avoid printing out the URL in error messages + #[allow(unused)] + pub(crate) fn sensitive(self, sensitive: bool) -> Self { + Self { sensitive, ..self } + } + + /// Provide a [`PutPayload`] + pub(crate) fn payload(self, payload: Option) -> Self { + Self { payload, ..self } + } + + #[allow(unused)] + pub(crate) fn retry_error_body(self, retry_error_body: bool) -> Self { + Self { + retry_error_body, + ..self + } + } + + pub(crate) async fn send(self) -> Result { + let mut ctx = RetryContext { + retries: 0, + uri: (!self.sensitive).then(|| self.request.uri().clone()), + method: self.request.method().clone(), + max_retries: self.max_retries, + start: Instant::now(), + retry_timeout: self.retry_timeout, + }; + + let mut backoff = self.backoff; + let is_idempotent = self + .idempotent + .unwrap_or_else(|| self.request.method().is_safe()); + + loop { + let mut request = self.request.clone(); + + if let Some(payload) = &self.payload { + *request.body_mut() = payload.clone().into(); + } + + match self.client.execute(request).await { + Ok(r) => { + let status = r.status(); + if status.is_success() { + // For certain S3 requests, 200 response may contain `InternalError` or + // `SlowDown` in the message. These responses should be handled similarly + // to r5xx errors. 
+ // More info here: https://repost.aws/knowledge-center/s3-resolve-200-internalerror + if !self.retry_error_body { + return Ok(r); + } + + let (parts, body) = r.into_parts(); + let body = match body.text().await { + Ok(body) => body, + Err(e) => return Err(ctx.err(RequestError::Http(e))), + }; + + if !body_contains_error(&body) { + // Success response and no error, clone and return response + return Ok(HttpResponse::from_parts(parts, body.into())); + } else { + // Retry as if this was a 5xx response + if ctx.exhausted() { + return Err(ctx.err(RequestError::Response { body, status })); + } + + let sleep = backoff.next(); + ctx.retries += 1; + info!( + "Encountered a response status of {} but body contains Error, backing off for {} seconds, retry {} of {}", + status, + sleep.as_secs_f32(), + ctx.retries, + ctx.max_retries, + ); + tokio::time::sleep(sleep).await; + } + } else if status == StatusCode::NOT_MODIFIED { + return Err(ctx.err(RequestError::Status { status, body: None })); + } else if status.is_redirection() { + let is_bare_redirect = !r.headers().contains_key(LOCATION); + return match is_bare_redirect { + true => Err(ctx.err(RequestError::BareRedirect)), + false => Err(ctx.err(RequestError::Status { + body: None, + status: r.status(), + })), + }; + } else { + let status = r.status(); + if ctx.exhausted() + || !(status.is_server_error() + || (self.retry_on_conflict && status == StatusCode::CONFLICT)) + { + let source = match status.is_client_error() { + true => match r.into_body().text().await { + Ok(body) => RequestError::Status { + status, + body: Some(body), + }, + Err(e) => RequestError::Http(e), + }, + false => RequestError::Status { status, body: None }, + }; + return Err(ctx.err(source)); + }; + + let sleep = backoff.next(); + ctx.retries += 1; + info!( + "Encountered server error, backing off for {} seconds, retry {} of {}", + sleep.as_secs_f32(), + ctx.retries, + ctx.max_retries, + ); + tokio::time::sleep(sleep).await; + } + } + Err(e) => { + // let e = sanitize_err(e); + + let do_retry = match e.kind() { + HttpErrorKind::Connect | HttpErrorKind::Request => true, // Request not sent, can retry + HttpErrorKind::Timeout | HttpErrorKind::Interrupted => is_idempotent, + HttpErrorKind::Unknown | HttpErrorKind::Decode => false, + }; + + if ctx.retries == ctx.max_retries + || ctx.start.elapsed() > ctx.retry_timeout + || !do_retry + { + return Err(ctx.err(RequestError::Http(e))); + } + let sleep = backoff.next(); + ctx.retries += 1; + info!( + "Encountered transport error backing off for {} seconds, retry {} of {}: {}", + sleep.as_secs_f32(), + ctx.retries, + ctx.max_retries, + e, + ); + tokio::time::sleep(sleep).await; + } + } + } + } +} + +pub(crate) trait RetryExt { + /// Return a [`RetryableRequest`] + fn retryable(self, config: &RetryConfig) -> RetryableRequest; + + /// Dispatch a request with the given retry configuration + /// + /// # Panic + /// + /// This will panic if the request body is a stream + fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; +} + +impl RetryExt for HttpRequestBuilder { + fn retryable(self, config: &RetryConfig) -> RetryableRequest { + let (client, request) = self.into_parts(); + let request = request.expect("request must be valid"); + + RetryableRequest { + client, + request, + max_retries: config.max_retries, + retry_timeout: config.retry_timeout, + backoff: Backoff::new(&config.backoff), + idempotent: None, + payload: None, + sensitive: false, + retry_on_conflict: false, + retry_error_body: false, + } + } + + fn 
send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { + let request = self.retryable(config); + Box::pin(async move { request.send().await }) + } +} + +#[cfg(test)] +mod tests { + use crate::client::mock_server::MockServer; + use crate::client::retry::{body_contains_error, RequestError, RetryExt}; + use crate::client::HttpClient; + use crate::RetryConfig; + use hyper::header::LOCATION; + use hyper::Response; + use reqwest::{Client, Method, StatusCode}; + use std::time::Duration; + + #[test] + fn test_body_contains_error() { + // Example error message provided by https://repost.aws/knowledge-center/s3-resolve-200-internalerror + let error_response = "AmazonS3Exception: We encountered an internal error. Please try again. (Service: Amazon S3; Status Code: 200; Error Code: InternalError; Request ID: 0EXAMPLE9AAEB265)"; + assert!(body_contains_error(error_response)); + + let error_response_2 = "SlowDownPlease reduce your request rate.123456"; + assert!(body_contains_error(error_response_2)); + + // Example success response from https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html + let success_response = "2009-10-12T17:50:30.000Z\"9b2cf535f27731c974343645a3985328\""; + assert!(!body_contains_error(success_response)); + } + + #[tokio::test] + async fn test_retry() { + let mock = MockServer::new().await; + + let retry = RetryConfig { + backoff: Default::default(), + max_retries: 2, + retry_timeout: Duration::from_secs(1000), + }; + + let client = HttpClient::new( + Client::builder() + .timeout(Duration::from_millis(100)) + .build() + .unwrap(), + ); + + let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); + + // Simple request should work + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Returns client errors immediately with status message + mock.push( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("cupcakes".to_string()) + .unwrap(), + ); + + let e = do_request().await.unwrap_err(); + assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + assert_eq!(e.body(), Some("cupcakes")); + assert_eq!( + e.inner().to_string(), + "Server returned non-2xx status code: 400 Bad Request: cupcakes" + ); + + // Handles client errors with no payload + mock.push( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("NAUGHTY NAUGHTY".to_string()) + .unwrap(), + ); + + let e = do_request().await.unwrap_err(); + assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + assert_eq!(e.body(), Some("NAUGHTY NAUGHTY")); + assert_eq!( + e.inner().to_string(), + "Server returned non-2xx status code: 400 Bad Request: NAUGHTY NAUGHTY" + ); + + // Should retry server error request + mock.push( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(String::new()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Accepts 204 status code + mock.push( + Response::builder() + .status(StatusCode::NO_CONTENT) + .body(String::new()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::NO_CONTENT); + + // Follows 402 redirects + mock.push( + Response::builder() + .status(StatusCode::FOUND) + .header(LOCATION, "/foo") + .body(String::new()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Follows 401 redirects + mock.push( + Response::builder() + .status(StatusCode::FOUND) + .header(LOCATION, "/bar") + .body(String::new()) + .unwrap(), + 
); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Handles redirect loop + for _ in 0..10 { + mock.push( + Response::builder() + .status(StatusCode::FOUND) + .header(LOCATION, "/bar") + .body(String::new()) + .unwrap(), + ); + } + + let e = do_request().await.unwrap_err().to_string(); + assert!(e.contains("error following redirect"), "{}", e); + + // Handles redirect missing location + mock.push( + Response::builder() + .status(StatusCode::FOUND) + .body(String::new()) + .unwrap(), + ); + + let e = do_request().await.unwrap_err(); + assert!(matches!(e.inner, RequestError::BareRedirect)); + assert_eq!(e.inner().to_string(), "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); + + // Gives up after the retrying the specified number of times + for _ in 0..=retry.max_retries { + mock.push( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body("ignored".to_string()) + .unwrap(), + ); + } + + let e = do_request().await.unwrap_err().to_string(); + assert!( + e.contains(" after 2 retries, max_retries: 2, retry_timeout: 1000s - Server returned non-2xx status code: 502 Bad Gateway"), + "{e}" + ); + + // Panic results in an incomplete message error in the client + mock.push_fn(|_| panic!()); + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Gives up after retrying multiple panics + for _ in 0..=retry.max_retries { + mock.push_fn(|_| panic!()); + } + let e = do_request().await.unwrap_err().to_string(); + assert!( + e.contains("after 2 retries, max_retries: 2, retry_timeout: 1000s - HTTP error: error sending request"), + "{e}" + ); + + // Retries on client timeout + mock.push_async_fn(|_| async move { + tokio::time::sleep(Duration::from_secs(10)).await; + panic!() + }); + do_request().await.unwrap(); + + // Does not retry PUT request + mock.push_async_fn(|_| async move { + tokio::time::sleep(Duration::from_secs(10)).await; + panic!() + }); + let res = client.request(Method::PUT, mock.url()).send_retry(&retry); + let e = res.await.unwrap_err().to_string(); + assert!( + !e.contains("retries") && e.contains("error sending request"), + "{e}" + ); + + let url = format!("{}/SENSITIVE", mock.url()); + for _ in 0..=retry.max_retries { + mock.push( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body("ignored".to_string()) + .unwrap(), + ); + } + let res = client.request(Method::GET, url).send_retry(&retry).await; + let err = res.unwrap_err().to_string(); + assert!(err.contains("SENSITIVE"), "{err}"); + + let url = format!("{}/SENSITIVE", mock.url()); + for _ in 0..=retry.max_retries { + mock.push( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body("ignored".to_string()) + .unwrap(), + ); + } + + // Sensitive requests should strip URL from error + let req = client + .request(Method::GET, &url) + .retryable(&retry) + .sensitive(true); + let err = req.send().await.unwrap_err().to_string(); + assert!(!err.contains("SENSITIVE"), "{err}"); + + for _ in 0..=retry.max_retries { + mock.push_fn(|_| panic!()); + } + + let req = client + .request(Method::GET, &url) + .retryable(&retry) + .sensitive(true); + let err = req.send().await.unwrap_err().to_string(); + assert!(!err.contains("SENSITIVE"), "{err}"); + + // Success response with error in body is retried + mock.push( + Response::builder() + .status(StatusCode::OK) + .body("InternalError".to_string()) + .unwrap(), + ); + let req = client + .request(Method::PUT, &url) + .retryable(&retry) + .idempotent(true) + 
.retry_error_body(true); + let r = req.send().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + // Response with InternalError should have been retried + let b = r.into_body().text().await.unwrap(); + assert!(!b.contains("InternalError")); + + // Should not retry success response with no error in body + mock.push( + Response::builder() + .status(StatusCode::OK) + .body("success".to_string()) + .unwrap(), + ); + let req = client + .request(Method::PUT, &url) + .retryable(&retry) + .idempotent(true) + .retry_error_body(true); + let r = req.send().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + let b = r.into_body().text().await.unwrap(); + assert!(b.contains("success")); + + // Shutdown + mock.shutdown().await + } +} diff --git a/src/client/s3.rs b/src/client/s3.rs new file mode 100644 index 0000000..a2221fb --- /dev/null +++ b/src/client/s3.rs @@ -0,0 +1,157 @@ +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The list and multipart API used by both GCS and S3 + +use crate::multipart::PartId; +use crate::path::Path; +use crate::{ListResult, ObjectMeta, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListResponse { + #[serde(default)] + pub contents: Vec, + #[serde(default)] + pub common_prefixes: Vec, + #[serde(default)] + pub next_continuation_token: Option, +} + +impl TryFrom for ListResult { + type Error = crate::Error; + + fn try_from(value: ListResponse) -> Result { + let common_prefixes = value + .common_prefixes + .into_iter() + .map(|x| Ok(Path::parse(x.prefix)?)) + .collect::>()?; + + let objects = value + .contents + .into_iter() + .map(TryFrom::try_from) + .collect::>()?; + + Ok(Self { + common_prefixes, + objects, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListPrefix { + pub prefix: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListContents { + pub key: String, + pub size: u64, + pub last_modified: DateTime, + #[serde(rename = "ETag")] + pub e_tag: Option, +} + +impl TryFrom for ObjectMeta { + type Error = crate::Error; + + fn try_from(value: ListContents) -> Result { + Ok(Self { + location: Path::parse(value.key)?, + last_modified: value.last_modified, + size: value.size, + e_tag: value.e_tag, + version: None, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct InitiateMultipartUploadResult { + pub upload_id: String, +} + +#[cfg(feature = "aws")] +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct CopyPartResult { + #[serde(rename = "ETag")] + pub e_tag: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct CompleteMultipartUpload { + 
pub part: Vec, +} + +#[derive(Serialize, Deserialize)] +pub(crate) struct PartMetadata { + pub e_tag: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum_sha256: Option, +} + +impl From> for CompleteMultipartUpload { + fn from(value: Vec) -> Self { + let part = value + .into_iter() + .enumerate() + .map(|(part_idx, part)| { + let md = match quick_xml::de::from_str::(&part.content_id) { + Ok(md) => md, + // fallback to old way + Err(_) => PartMetadata { + e_tag: part.content_id.clone(), + checksum_sha256: None, + }, + }; + MultipartPart { + e_tag: md.e_tag, + part_number: part_idx + 1, + checksum_sha256: md.checksum_sha256, + } + }) + .collect(); + Self { part } + } +} + +#[derive(Debug, Serialize)] +pub(crate) struct MultipartPart { + #[serde(rename = "ETag")] + pub e_tag: String, + #[serde(rename = "PartNumber")] + pub part_number: usize, + #[serde(rename = "ChecksumSHA256")] + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum_sha256: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub(crate) struct CompleteMultipartUploadResult { + #[serde(rename = "ETag")] + pub e_tag: String, +} diff --git a/src/client/token.rs b/src/client/token.rs new file mode 100644 index 0000000..81ffc11 --- /dev/null +++ b/src/client/token.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
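
> Editor's note (not part of this patch): the `From<Vec<PartId>>` conversion above packs either a serialized `PartMetadata` XML fragment or a bare ETag into each part's `content_id`, falling back to the bare-ETag interpretation when parsing fails. The sketch below illustrates that round trip with simplified local structs that mirror the crate-private ones; it assumes `serde` with derive and `quick_xml` built with its `serialize` feature.

```rust
use serde::{Deserialize, Serialize};

// Simplified stand-in for the crate-private PartMetadata in src/client/s3.rs
#[derive(Debug, Serialize, Deserialize)]
struct PartMetadata {
    e_tag: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    checksum_sha256: Option<String>,
}

/// Interpret a multipart `content_id`: newer uploads embed PartMetadata XML,
/// older ones store the bare ETag, so fall back to that on any parse failure.
fn parse_content_id(content_id: &str) -> PartMetadata {
    quick_xml::de::from_str::<PartMetadata>(content_id).unwrap_or_else(|_| PartMetadata {
        e_tag: content_id.to_string(),
        checksum_sha256: None,
    })
}

fn main() {
    // New format: XML carrying an ETag plus a SHA-256 checksum
    let with_checksum = quick_xml::se::to_string(&PartMetadata {
        e_tag: "\"etag-1\"".to_string(),
        checksum_sha256: Some("checksum==".to_string()),
    })
    .unwrap();
    let parsed = parse_content_id(&with_checksum);
    assert_eq!(parsed.checksum_sha256.as_deref(), Some("checksum=="));

    // Old format: the content id is just the ETag itself
    let legacy = parse_content_id("\"etag-1\"");
    assert_eq!(legacy.e_tag, "\"etag-1\"");
    assert!(legacy.checksum_sha256.is_none());
}
```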
+ +use std::future::Future; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +/// A temporary authentication token with an associated expiry +#[derive(Debug, Clone)] +pub(crate) struct TemporaryToken { + /// The temporary credential + pub token: T, + /// The instant at which this credential is no longer valid + /// None means the credential does not expire + pub expiry: Option, +} + +/// Provides [`TokenCache::get_or_insert_with`] which can be used to cache a +/// [`TemporaryToken`] based on its expiry +#[derive(Debug)] +pub(crate) struct TokenCache { + cache: Mutex, Instant)>>, + min_ttl: Duration, + fetch_backoff: Duration, +} + +impl Default for TokenCache { + fn default() -> Self { + Self { + cache: Default::default(), + min_ttl: Duration::from_secs(300), + // How long to wait before re-attempting a token fetch after receiving one that + // is still within the min-ttl + fetch_backoff: Duration::from_millis(100), + } + } +} + +impl TokenCache { + /// Override the minimum remaining TTL for a cached token to be used + #[cfg(any(feature = "aws", feature = "gcp"))] + pub(crate) fn with_min_ttl(self, min_ttl: Duration) -> Self { + Self { min_ttl, ..self } + } + + pub(crate) async fn get_or_insert_with(&self, f: F) -> Result + where + F: FnOnce() -> Fut + Send, + Fut: Future, E>> + Send, + { + let now = Instant::now(); + let mut locked = self.cache.lock().await; + + if let Some((cached, fetched_at)) = locked.as_ref() { + match cached.expiry { + Some(ttl) => { + if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl || + // if we've recently attempted to fetch this token and it's not actually + // expired, we'll wait to re-fetch it and return the cached one + (fetched_at.elapsed() < self.fetch_backoff && ttl.checked_duration_since(now).is_some()) + { + return Ok(cached.token.clone()); + } + } + None => return Ok(cached.token.clone()), + } + } + + let cached = f().await?; + let token = cached.token.clone(); + *locked = Some((cached, Instant::now())); + + Ok(token) + } +} + +#[cfg(test)] +mod test { + use crate::client::token::{TemporaryToken, TokenCache}; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::time::{Duration, Instant}; + + // Helper function to create a token with a specific expiry duration from now + fn create_token(expiry_duration: Option) -> TemporaryToken { + TemporaryToken { + token: "test_token".to_string(), + expiry: expiry_duration.map(|d| Instant::now() + d), + } + } + + #[tokio::test] + async fn test_expired_token_is_refreshed() { + let cache = TokenCache::default(); + static COUNTER: AtomicU32 = AtomicU32::new(0); + + async fn get_token() -> Result, String> { + COUNTER.fetch_add(1, Ordering::SeqCst); + Ok::<_, String>(create_token(Some(Duration::from_secs(0)))) + } + + // Should fetch initial token + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + tokio::time::sleep(Duration::from_millis(2)).await; + + // Token is expired, so should fetch again + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_min_ttl_causes_refresh() { + let cache = TokenCache { + cache: Default::default(), + min_ttl: Duration::from_secs(1), + fetch_backoff: Duration::from_millis(1), + }; + + static COUNTER: AtomicU32 = AtomicU32::new(0); + + async fn get_token() -> Result, String> { + COUNTER.fetch_add(1, Ordering::SeqCst); + Ok::<_, String>(create_token(Some(Duration::from_millis(100)))) + } + + // 
Initial fetch + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + // Should not fetch again since not expired and within fetch_backoff + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + tokio::time::sleep(Duration::from_millis(2)).await; + + // Should fetch, since we've passed fetch_backoff + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 2); + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..29a389d --- /dev/null +++ b/src/config.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use std::fmt::{Debug, Display, Formatter}; +use std::str::FromStr; +use std::time::Duration; + +use humantime::{format_duration, parse_duration}; +use reqwest::header::HeaderValue; + +use crate::{Error, Result}; + +/// Provides deferred parsing of a value +/// +/// This allows builders to defer fallibility to build +#[derive(Debug, Clone)] +pub(crate) enum ConfigValue { + Parsed(T), + Deferred(String), +} + +impl Display for ConfigValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Parsed(v) => write!(f, "{v}"), + Self::Deferred(v) => write!(f, "{v}"), + } + } +} + +impl From for ConfigValue { + fn from(value: T) -> Self { + Self::Parsed(value) + } +} + +impl ConfigValue { + pub(crate) fn parse(&mut self, v: impl Into) { + *self = Self::Deferred(v.into()) + } + + pub(crate) fn get(&self) -> Result { + match self { + Self::Parsed(v) => Ok(v.clone()), + Self::Deferred(v) => T::parse(v), + } + } +} + +impl Default for ConfigValue { + fn default() -> Self { + Self::Parsed(T::default()) + } +} + +/// A value that can be stored in [`ConfigValue`] +pub(crate) trait Parse: Sized { + fn parse(v: &str) -> Result; +} + +impl Parse for bool { + fn parse(v: &str) -> Result { + let lower = v.to_ascii_lowercase(); + match lower.as_str() { + "1" | "true" | "on" | "yes" | "y" => Ok(true), + "0" | "false" | "off" | "no" | "n" => Ok(false), + _ => Err(Error::Generic { + store: "Config", + source: format!("failed to parse \"{v}\" as boolean").into(), + }), + } + } +} + +impl Parse for Duration { + fn parse(v: &str) -> Result { + parse_duration(v).map_err(|_| Error::Generic { + store: "Config", + source: format!("failed to parse \"{v}\" as Duration").into(), + }) + } +} + +impl Parse for usize { + fn parse(v: &str) -> Result { + Self::from_str(v).map_err(|_| Error::Generic { + store: "Config", + source: format!("failed to parse \"{v}\" as usize").into(), + }) + } +} + +impl Parse for u32 { + fn parse(v: &str) -> Result { + Self::from_str(v).map_err(|_| Error::Generic { + store: "Config", + source: 
format!("failed to parse \"{v}\" as u32").into(), + }) + } +} + +impl Parse for HeaderValue { + fn parse(v: &str) -> Result { + Self::from_str(v).map_err(|_| Error::Generic { + store: "Config", + source: format!("failed to parse \"{v}\" as HeaderValue").into(), + }) + } +} + +pub(crate) fn fmt_duration(duration: &ConfigValue) -> String { + match duration { + ConfigValue::Parsed(v) => format_duration(*v).to_string(), + ConfigValue::Deferred(v) => v.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_parse_duration() { + let duration = Duration::from_secs(60); + assert_eq!(Duration::parse("60 seconds").unwrap(), duration); + assert_eq!(Duration::parse("60 s").unwrap(), duration); + assert_eq!(Duration::parse("60s").unwrap(), duration) + } +} diff --git a/src/delimited.rs b/src/delimited.rs new file mode 100644 index 0000000..5b11a0b --- /dev/null +++ b/src/delimited.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utility for streaming newline delimited files from object storage + +use std::collections::VecDeque; + +use bytes::Bytes; +use futures::{Stream, StreamExt}; + +use super::Result; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("encountered unterminated string")] + UnterminatedString, + + #[error("encountered trailing escape character")] + TrailingEscape, +} + +impl From for super::Error { + fn from(err: Error) -> Self { + Self::Generic { + store: "LineDelimiter", + source: Box::new(err), + } + } +} + +/// The ASCII encoding of `"` +const QUOTE: u8 = b'"'; + +/// The ASCII encoding of `\n` +const NEWLINE: u8 = b'\n'; + +/// The ASCII encoding of `\` +const ESCAPE: u8 = b'\\'; + +/// [`LineDelimiter`] is provided with a stream of [`Bytes`] and returns an iterator +/// of [`Bytes`] containing a whole number of new line delimited records +#[derive(Debug, Default)] +struct LineDelimiter { + /// Complete chunks of [`Bytes`] + complete: VecDeque, + /// Remainder bytes that form the next record + remainder: Vec, + /// True if the last character was the escape character + is_escape: bool, + /// True if currently processing a quoted string + is_quote: bool, +} + +impl LineDelimiter { + /// Creates a new [`LineDelimiter`] with the provided delimiter + fn new() -> Self { + Self::default() + } + + /// Adds the next set of [`Bytes`] + fn push(&mut self, val: impl Into) { + let val: Bytes = val.into(); + + let is_escape = &mut self.is_escape; + let is_quote = &mut self.is_quote; + let mut record_ends = val.iter().enumerate().filter_map(|(idx, v)| { + if *is_escape { + *is_escape = false; + None + } else if *v == ESCAPE { + *is_escape = true; + None + } else if *v == QUOTE { + *is_quote = !*is_quote; + None + } else if *is_quote { + None + } else { + 
(*v == NEWLINE).then_some(idx + 1) + } + }); + + let start_offset = match self.remainder.is_empty() { + true => 0, + false => match record_ends.next() { + Some(idx) => { + self.remainder.extend_from_slice(&val[0..idx]); + self.complete + .push_back(Bytes::from(std::mem::take(&mut self.remainder))); + idx + } + None => { + self.remainder.extend_from_slice(&val); + return; + } + }, + }; + let end_offset = record_ends.last().unwrap_or(start_offset); + if start_offset != end_offset { + self.complete.push_back(val.slice(start_offset..end_offset)); + } + + if end_offset != val.len() { + self.remainder.extend_from_slice(&val[end_offset..]) + } + } + + /// Marks the end of the stream, delimiting any remaining bytes + /// + /// Returns `true` if there is no remaining data to be read + fn finish(&mut self) -> Result { + if !self.remainder.is_empty() { + if self.is_quote { + Err(Error::UnterminatedString)?; + } + if self.is_escape { + Err(Error::TrailingEscape)?; + } + + self.complete + .push_back(Bytes::from(std::mem::take(&mut self.remainder))) + } + Ok(self.complete.is_empty()) + } +} + +impl Iterator for LineDelimiter { + type Item = Bytes; + + fn next(&mut self) -> Option { + self.complete.pop_front() + } +} + +/// Given a [`Stream`] of [`Bytes`] returns a [`Stream`] where each +/// yielded [`Bytes`] contains a whole number of new line delimited records +/// accounting for `\` style escapes and `"` quotes +pub fn newline_delimited_stream(s: S) -> impl Stream> +where + S: Stream> + Unpin, +{ + let delimiter = LineDelimiter::new(); + + futures::stream::unfold( + (s, delimiter, false), + |(mut s, mut delimiter, mut exhausted)| async move { + loop { + if let Some(next) = delimiter.next() { + return Some((Ok(next), (s, delimiter, exhausted))); + } else if exhausted { + return None; + } + + match s.next().await { + Some(Ok(bytes)) => delimiter.push(bytes), + Some(Err(e)) => return Some((Err(e), (s, delimiter, exhausted))), + None => { + exhausted = true; + match delimiter.finish() { + Ok(true) => return None, + Ok(false) => continue, + Err(e) => return Some((Err(e), (s, delimiter, exhausted))), + } + } + } + } + }, + ) +} + +#[cfg(test)] +mod tests { + use futures::stream::{BoxStream, TryStreamExt}; + + use super::*; + + #[test] + fn test_delimiter() { + let mut delimiter = LineDelimiter::new(); + delimiter.push("hello\nworld"); + delimiter.push("\n\n"); + + assert_eq!(delimiter.next().unwrap(), Bytes::from("hello\n")); + assert_eq!(delimiter.next().unwrap(), Bytes::from("world\n")); + assert_eq!(delimiter.next().unwrap(), Bytes::from("\n")); + assert!(delimiter.next().is_none()); + } + + #[test] + fn test_delimiter_escaped() { + let mut delimiter = LineDelimiter::new(); + delimiter.push(""); + delimiter.push("fo\\\n\"foo"); + delimiter.push("bo\n\"bar\n"); + delimiter.push("\"he"); + delimiter.push("llo\"\n"); + assert_eq!( + delimiter.next().unwrap(), + Bytes::from("fo\\\n\"foobo\n\"bar\n") + ); + assert_eq!(delimiter.next().unwrap(), Bytes::from("\"hello\"\n")); + assert!(delimiter.next().is_none()); + + // Verify can push further data + delimiter.push("\"foo\nbar\",\"fiz\\\"inner\\\"\"\nhello"); + assert!(!delimiter.finish().unwrap()); + + assert_eq!( + delimiter.next().unwrap(), + Bytes::from("\"foo\nbar\",\"fiz\\\"inner\\\"\"\n") + ); + assert_eq!(delimiter.next().unwrap(), Bytes::from("hello")); + assert!(delimiter.finish().unwrap()); + assert!(delimiter.next().is_none()); + } + + #[tokio::test] + async fn test_delimiter_stream() { + let input = vec!["hello\nworld\nbin", "go\ncup", "cakes"]; + 
let input_stream = futures::stream::iter(input.into_iter().map(|s| Ok(Bytes::from(s)))); + let stream = newline_delimited_stream(input_stream); + + let results: Vec<_> = stream.try_collect().await.unwrap(); + assert_eq!( + results, + vec![ + Bytes::from("hello\nworld\n"), + Bytes::from("bingo\n"), + Bytes::from("cupcakes") + ] + ) + } + #[tokio::test] + async fn test_delimiter_unfold_stream() { + let input_stream: BoxStream<'static, Result> = futures::stream::unfold( + VecDeque::from(["hello\nworld\nbin", "go\ncup", "cakes"]), + |mut input| async move { + if !input.is_empty() { + Some((Ok(Bytes::from(input.pop_front().unwrap())), input)) + } else { + None + } + }, + ) + .boxed(); + let stream = newline_delimited_stream(input_stream); + + let results: Vec<_> = stream.try_collect().await.unwrap(); + assert_eq!( + results, + vec![ + Bytes::from("hello\nworld\n"), + Bytes::from("bingo\n"), + Bytes::from("cupcakes") + ] + ) + } +} diff --git a/src/gcp/builder.rs b/src/gcp/builder.rs new file mode 100644 index 0000000..74aecae --- /dev/null +++ b/src/gcp/builder.rs @@ -0,0 +1,718 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::{http_connector, HttpConnector, TokenCredentialProvider}; +use crate::gcp::client::{GoogleCloudStorageClient, GoogleCloudStorageConfig}; +use crate::gcp::credential::{ + ApplicationDefaultCredentials, InstanceCredentialProvider, ServiceAccountCredentials, + DEFAULT_GCS_BASE_URL, +}; +use crate::gcp::{ + credential, GcpCredential, GcpCredentialProvider, GcpSigningCredential, + GcpSigningCredentialProvider, GoogleCloudStorage, STORE, +}; +use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider}; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use url::Url; + +use super::credential::{AuthorizedUserSigningCredentials, InstanceSigningCredentialProvider}; + +const TOKEN_MIN_TTL: Duration = Duration::from_secs(4 * 60); + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Missing bucket name")] + MissingBucketName {}, + + #[error("One of service account path or service account key may be provided.")] + ServiceAccountPathAndKeyProvided, + + #[error("Unable parse source url. 
Url: {}, Error: {}", url, source)] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[error( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + )] + UnknownUrlScheme { scheme: String }, + + #[error("URL did not match any known pattern for scheme: {}", url)] + UrlNotRecognised { url: String }, + + #[error("Configuration key: '{}' is not known.", key)] + UnknownConfigurationKey { key: String }, + + #[error("GCP credential error: {}", source)] + Credential { source: credential::Error }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +/// Configure a connection to Google Cloud Storage. +/// +/// If no credentials are explicitly provided, they will be sourced +/// from the environment as documented [here](https://cloud.google.com/docs/authentication/application-default-credentials). +/// +/// # Example +/// ``` +/// # let BUCKET_NAME = "foo"; +/// # use object_store::gcp::GoogleCloudStorageBuilder; +/// let gcs = GoogleCloudStorageBuilder::from_env().with_bucket_name(BUCKET_NAME).build(); +/// ``` +#[derive(Debug, Clone)] +pub struct GoogleCloudStorageBuilder { + /// Bucket name + bucket_name: Option, + /// Url + url: Option, + /// Path to the service account file + service_account_path: Option, + /// The serialized service account key + service_account_key: Option, + /// Path to the application credentials file. + application_credentials_path: Option, + /// Retry config + retry_config: RetryConfig, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, + /// Credentials for sign url + signing_credentials: Option, + /// The [`HttpConnector`] to use + http_connector: Option>, +} + +/// Configuration keys for [`GoogleCloudStorageBuilder`] +/// +/// Configuration via keys can be done via [`GoogleCloudStorageBuilder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; +/// let builder = GoogleCloudStorageBuilder::new() +/// .with_config("google_service_account".parse().unwrap(), "my-service-account") +/// .with_config(GoogleConfigKey::Bucket, "my-bucket"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] +#[non_exhaustive] +pub enum GoogleConfigKey { + /// Path to the service account file + /// + /// Supported keys: + /// - `google_service_account` + /// - `service_account` + /// - `google_service_account_path` + /// - `service_account_path` + ServiceAccount, + + /// The serialized service account key. + /// + /// Supported keys: + /// - `google_service_account_key` + /// - `service_account_key` + ServiceAccountKey, + + /// Bucket name + /// + /// See [`GoogleCloudStorageBuilder::with_bucket_name`] for details. + /// + /// Supported keys: + /// - `google_bucket` + /// - `google_bucket_name` + /// - `bucket` + /// - `bucket_name` + Bucket, + + /// Application credentials path + /// + /// See [`GoogleCloudStorageBuilder::with_application_credentials`]. 
+ ApplicationCredentials, + + /// Client options + Client(ClientConfigKey), +} + +impl AsRef for GoogleConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::ServiceAccount => "google_service_account", + Self::ServiceAccountKey => "google_service_account_key", + Self::Bucket => "google_bucket", + Self::ApplicationCredentials => "google_application_credentials", + Self::Client(key) => key.as_ref(), + } + } +} + +impl FromStr for GoogleConfigKey { + type Err = crate::Error; + + fn from_str(s: &str) -> Result { + match s { + "google_service_account" + | "service_account" + | "google_service_account_path" + | "service_account_path" => Ok(Self::ServiceAccount), + "google_service_account_key" | "service_account_key" => Ok(Self::ServiceAccountKey), + "google_bucket" | "google_bucket_name" | "bucket" | "bucket_name" => Ok(Self::Bucket), + "google_application_credentials" => Ok(Self::ApplicationCredentials), + _ => match s.strip_prefix("google_").unwrap_or(s).parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl Default for GoogleCloudStorageBuilder { + fn default() -> Self { + Self { + bucket_name: None, + service_account_path: None, + service_account_key: None, + application_credentials_path: None, + retry_config: Default::default(), + client_options: ClientOptions::new().with_allow_http(true), + url: None, + credentials: None, + signing_credentials: None, + http_connector: None, + } + } +} + +impl GoogleCloudStorageBuilder { + /// Create a new [`GoogleCloudStorageBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Create an instance of [`GoogleCloudStorageBuilder`] with values pre-populated from environment variables. + /// + /// Variables extracted from environment: + /// * GOOGLE_SERVICE_ACCOUNT: location of service account file + /// * GOOGLE_SERVICE_ACCOUNT_PATH: (alias) location of service account file + /// * SERVICE_ACCOUNT: (alias) location of service account file + /// * GOOGLE_SERVICE_ACCOUNT_KEY: JSON serialized service account key + /// * GOOGLE_BUCKET: bucket name + /// * GOOGLE_BUCKET_NAME: (alias) bucket name + /// + /// # Example + /// ``` + /// use object_store::gcp::GoogleCloudStorageBuilder; + /// + /// let gcs = GoogleCloudStorageBuilder::from_env() + /// .with_bucket_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder = Self::default(); + + if let Ok(service_account_path) = std::env::var("SERVICE_ACCOUNT") { + builder.service_account_path = Some(service_account_path); + } + + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("GOOGLE_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + builder + } + + /// Parse available connection info form a well-known storage URL. + /// + /// The supported url schemes are: + /// + /// - `gs:///` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::gcp::GoogleCloudStorageBuilder; + /// + /// let gcs = GoogleCloudStorageBuilder::from_env() + /// .with_url("gs://bucket/path") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair. 
+ pub fn with_config(mut self, key: GoogleConfigKey, value: impl Into) -> Self { + match key { + GoogleConfigKey::ServiceAccount => self.service_account_path = Some(value.into()), + GoogleConfigKey::ServiceAccountKey => self.service_account_key = Some(value.into()), + GoogleConfigKey::Bucket => self.bucket_name = Some(value.into()), + GoogleConfigKey::ApplicationCredentials => { + self.application_credentials_path = Some(value.into()) + } + GoogleConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + }; + self + } + + /// Get config value via a [`GoogleConfigKey`]. + /// + /// # Example + /// ``` + /// use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; + /// + /// let builder = GoogleCloudStorageBuilder::from_env() + /// .with_service_account_key("foo"); + /// let service_account_key = builder.get_config_value(&GoogleConfigKey::ServiceAccountKey).unwrap_or_default(); + /// assert_eq!("foo", &service_account_key); + /// ``` + pub fn get_config_value(&self, key: &GoogleConfigKey) -> Option { + match key { + GoogleConfigKey::ServiceAccount => self.service_account_path.clone(), + GoogleConfigKey::ServiceAccountKey => self.service_account_key.clone(), + GoogleConfigKey::Bucket => self.bucket_name.clone(), + GoogleConfigKey::ApplicationCredentials => self.application_credentials_path.clone(), + GoogleConfigKey::Client(key) => self.client_options.get_config_value(key), + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).map_err(|source| Error::UnableToParseUrl { + source, + url: url.to_string(), + })?; + + let host = parsed.host_str().ok_or_else(|| Error::UrlNotRecognised { + url: url.to_string(), + })?; + + match parsed.scheme() { + "gs" => self.bucket_name = Some(host.to_string()), + scheme => { + let scheme = scheme.to_string(); + return Err(Error::UnknownUrlScheme { scheme }.into()); + } + } + Ok(()) + } + + /// Set the bucket name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Set the path to the service account file. + /// + /// This or [`GoogleCloudStorageBuilder::with_service_account_key`] must be + /// set. + /// + /// Example `"/tmp/gcs.json"`. + /// + /// Example contents of `gcs.json`: + /// + /// ```json + /// { + /// "gcs_base_url": "https://localhost:4443", + /// "disable_oauth": true, + /// "client_email": "", + /// "private_key": "" + /// } + /// ``` + pub fn with_service_account_path(mut self, service_account_path: impl Into) -> Self { + self.service_account_path = Some(service_account_path.into()); + self + } + + /// Set the service account key. The service account must be in the JSON + /// format. + /// + /// This or [`GoogleCloudStorageBuilder::with_service_account_path`] must be + /// set. + pub fn with_service_account_key(mut self, service_account: impl Into) -> Self { + self.service_account_key = Some(service_account.into()); + self + } + + /// Set the path to the application credentials file. 
+ /// + /// + pub fn with_application_credentials( + mut self, + application_credentials_path: impl Into, + ) -> Self { + self.application_credentials_path = Some(application_credentials_path.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate(mut self, proxy_ca_certificate: impl Into) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// The [`HttpConnector`] to use + /// + /// On non-WASM32 platforms uses [`reqwest`] by default, on WASM32 platforms must be provided + pub fn with_http_connector(mut self, connector: C) -> Self { + self.http_connector = Some(Arc::new(connector)); + self + } + + /// Configure a connection to Google Cloud Storage, returning a + /// new [`GoogleCloudStorage`] and consuming `self` + pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let bucket_name = self.bucket_name.ok_or(Error::MissingBucketName {})?; + + let http = http_connector(self.http_connector)?; + + // First try to initialize from the service account information. + let service_account_credentials = + match (self.service_account_path, self.service_account_key) { + (Some(path), None) => Some( + ServiceAccountCredentials::from_file(path) + .map_err(|source| Error::Credential { source })?, + ), + (None, Some(key)) => Some( + ServiceAccountCredentials::from_key(&key) + .map_err(|source| Error::Credential { source })?, + ), + (None, None) => None, + (Some(_), Some(_)) => return Err(Error::ServiceAccountPathAndKeyProvided.into()), + }; + + // Then try to initialize from the application credentials file, or the environment. 
+ let application_default_credentials = + ApplicationDefaultCredentials::read(self.application_credentials_path.as_deref())?; + + let disable_oauth = service_account_credentials + .as_ref() + .map(|c| c.disable_oauth) + .unwrap_or(false); + + let gcs_base_url: String = service_account_credentials + .as_ref() + .and_then(|c| c.gcs_base_url.clone()) + .unwrap_or_else(|| DEFAULT_GCS_BASE_URL.to_string()); + + let credentials = if let Some(credentials) = self.credentials { + credentials + } else if disable_oauth { + Arc::new(StaticCredentialProvider::new(GcpCredential { + bearer: "".to_string(), + })) as _ + } else if let Some(credentials) = service_account_credentials.clone() { + Arc::new(TokenCredentialProvider::new( + credentials.token_provider()?, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } else if let Some(credentials) = application_default_credentials.clone() { + match credentials { + ApplicationDefaultCredentials::AuthorizedUser(token) => Arc::new( + TokenCredentialProvider::new( + token, + http.connect(&self.client_options)?, + self.retry_config.clone(), + ) + .with_min_ttl(TOKEN_MIN_TTL), + ) as _, + ApplicationDefaultCredentials::ServiceAccount(token) => { + Arc::new(TokenCredentialProvider::new( + token.token_provider()?, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } + } + } else { + Arc::new( + TokenCredentialProvider::new( + InstanceCredentialProvider::default(), + http.connect(&self.client_options.metadata_options())?, + self.retry_config.clone(), + ) + .with_min_ttl(TOKEN_MIN_TTL), + ) as _ + }; + + let signing_credentials = if let Some(signing_credentials) = self.signing_credentials { + signing_credentials + } else if disable_oauth { + Arc::new(StaticCredentialProvider::new(GcpSigningCredential { + email: "".to_string(), + private_key: None, + })) as _ + } else if let Some(credentials) = service_account_credentials.clone() { + credentials.signing_credentials()? + } else if let Some(credentials) = application_default_credentials.clone() { + match credentials { + ApplicationDefaultCredentials::AuthorizedUser(token) => { + Arc::new(TokenCredentialProvider::new( + AuthorizedUserSigningCredentials::from(token)?, + http.connect(&self.client_options)?, + self.retry_config.clone(), + )) as _ + } + ApplicationDefaultCredentials::ServiceAccount(token) => { + token.signing_credentials()? 
+ } + } + } else { + Arc::new(TokenCredentialProvider::new( + InstanceSigningCredentialProvider::default(), + http.connect(&self.client_options.metadata_options())?, + self.retry_config.clone(), + )) as _ + }; + + let config = GoogleCloudStorageConfig::new( + gcs_base_url, + credentials, + signing_credentials, + bucket_name, + self.retry_config, + self.client_options, + ); + + let http_client = http.connect(&config.client_options)?; + Ok(GoogleCloudStorage { + client: Arc::new(GoogleCloudStorageClient::new(config, http_client)?), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use std::io::Write; + use tempfile::NamedTempFile; + + const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", "client_email":"client_email", "disable_oauth":true}"#; + + #[test] + fn gcs_test_service_account_key_and_path() { + let mut tfile = NamedTempFile::new().unwrap(); + write!(tfile, "{FAKE_KEY}").unwrap(); + let _ = GoogleCloudStorageBuilder::new() + .with_service_account_key(FAKE_KEY) + .with_service_account_path(tfile.path().to_str().unwrap()) + .with_bucket_name("foo") + .build() + .unwrap_err(); + } + + #[test] + fn gcs_test_config_from_map() { + let google_service_account = "object_store:fake_service_account".to_string(); + let google_bucket_name = "object_store:fake_bucket".to_string(); + let options = HashMap::from([ + ("google_service_account", google_service_account.clone()), + ("google_bucket_name", google_bucket_name.clone()), + ]); + + let builder = options + .iter() + .fold(GoogleCloudStorageBuilder::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }); + + assert_eq!( + builder.service_account_path.unwrap(), + google_service_account.as_str() + ); + assert_eq!(builder.bucket_name.unwrap(), google_bucket_name.as_str()); + } + + #[test] + fn gcs_test_config_aliases() { + // Service account path + for alias in [ + "google_service_account", + "service_account", + "google_service_account_path", + "service_account_path", + ] { + let builder = GoogleCloudStorageBuilder::new() + .with_config(alias.parse().unwrap(), "/fake/path.json"); + assert_eq!("/fake/path.json", builder.service_account_path.unwrap()); + } + + // Service account key + for alias in ["google_service_account_key", "service_account_key"] { + let builder = + GoogleCloudStorageBuilder::new().with_config(alias.parse().unwrap(), FAKE_KEY); + assert_eq!(FAKE_KEY, builder.service_account_key.unwrap()); + } + + // Bucket name + for alias in [ + "google_bucket", + "google_bucket_name", + "bucket", + "bucket_name", + ] { + let builder = + GoogleCloudStorageBuilder::new().with_config(alias.parse().unwrap(), "fake_bucket"); + assert_eq!("fake_bucket", builder.bucket_name.unwrap()); + } + } + + #[tokio::test] + async fn gcs_test_proxy_url() { + let mut tfile = NamedTempFile::new().unwrap(); + write!(tfile, "{FAKE_KEY}").unwrap(); + let service_account_path = tfile.path(); + let gcs = GoogleCloudStorageBuilder::new() + .with_service_account_path(service_account_path.to_str().unwrap()) + .with_bucket_name("foo") + .with_proxy_url("https://example.com") + .build(); + assert!(gcs.is_ok()); + + let err = GoogleCloudStorageBuilder::new() + .with_service_account_path(service_account_path.to_str().unwrap()) + .with_bucket_name("foo") + .with_proxy_url("asdf://example.com") + .build() + .unwrap_err() + .to_string(); + + assert_eq!("Generic HTTP client error: builder error", err); + } + + #[test] + fn gcs_test_urls() { + let mut builder = 
GoogleCloudStorageBuilder::new(); + builder.parse_url("gs://bucket/path").unwrap(); + assert_eq!(builder.bucket_name.as_deref(), Some("bucket")); + + builder.parse_url("gs://bucket.mydomain/path").unwrap(); + assert_eq!(builder.bucket_name.as_deref(), Some("bucket.mydomain")); + + builder.parse_url("mailto://bucket/path").unwrap_err(); + } + + #[test] + fn gcs_test_service_account_key_only() { + let _ = GoogleCloudStorageBuilder::new() + .with_service_account_key(FAKE_KEY) + .with_bucket_name("foo") + .build() + .unwrap(); + } + + #[test] + fn gcs_test_config_get_value() { + let google_service_account = "object_store:fake_service_account".to_string(); + let google_bucket_name = "object_store:fake_bucket".to_string(); + let builder = GoogleCloudStorageBuilder::new() + .with_config(GoogleConfigKey::ServiceAccount, &google_service_account) + .with_config(GoogleConfigKey::Bucket, &google_bucket_name); + + assert_eq!( + builder + .get_config_value(&GoogleConfigKey::ServiceAccount) + .unwrap(), + google_service_account + ); + assert_eq!( + builder.get_config_value(&GoogleConfigKey::Bucket).unwrap(), + google_bucket_name + ); + } + + #[test] + fn gcp_test_client_opts() { + let key = "GOOGLE_PROXY_URL"; + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + assert_eq!( + GoogleConfigKey::Client(ClientConfigKey::ProxyUrl), + config_key + ); + } else { + panic!("{} not propagated as ClientConfigKey", key); + } + } +} diff --git a/src/gcp/client.rs b/src/gcp/client.rs new file mode 100644 index 0000000..1cc7296 --- /dev/null +++ b/src/gcp/client.rs @@ -0,0 +1,716 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
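
> Editor's note (not part of this patch): the builder above accepts configuration through typed setters, URLs, environment variables, or string keys and their documented aliases. A minimal sketch of that usage follows; the `gs://` URL and the option map are placeholders, and it assumes the published `object_store` crate built with the `gcp` feature.

```rust
use std::collections::HashMap;

use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey};
use object_store::ObjectStore;

/// Hypothetical helper combining environment, URL, and string-keyed options.
#[allow(dead_code)]
fn build_gcs(options: &HashMap<&str, String>) -> object_store::Result<impl ObjectStore> {
    // Start from the environment (GOOGLE_SERVICE_ACCOUNT, GOOGLE_BUCKET, ...)
    let mut builder = GoogleCloudStorageBuilder::from_env()
        // Settings derived from the URL override other settings at build() time
        .with_url("gs://example-bucket/path");

    // String keys such as "google_service_account" or "bucket_name" parse into
    // GoogleConfigKey (including the documented aliases) via FromStr
    for (key, value) in options {
        builder = builder.with_config(key.parse::<GoogleConfigKey>()?, value.as_str());
    }

    builder.build()
}

fn main() {
    // The documented aliases all resolve to the same configuration key
    let key: GoogleConfigKey = "service_account_path".parse().unwrap();
    assert_eq!(key, GoogleConfigKey::ServiceAccount);
}
```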
+ +use crate::client::builder::HttpRequestBuilder; +use crate::client::get::GetClient; +use crate::client::header::{get_put_result, get_version, HeaderConfig}; +use crate::client::list::ListClient; +use crate::client::retry::RetryExt; +use crate::client::s3::{ + CompleteMultipartUpload, CompleteMultipartUploadResult, InitiateMultipartUploadResult, + ListResponse, +}; +use crate::client::{GetOptionsExt, HttpClient, HttpError, HttpResponse}; +use crate::gcp::{GcpCredential, GcpCredentialProvider, GcpSigningCredentialProvider, STORE}; +use crate::multipart::PartId; +use crate::path::{Path, DELIMITER}; +use crate::util::hex_encode; +use crate::{ + Attribute, Attributes, ClientOptions, GetOptions, ListResult, MultipartId, PutMode, + PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, RetryConfig, +}; +use async_trait::async_trait; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Buf; +use http::header::{ + CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, + CONTENT_TYPE, +}; +use http::{HeaderName, Method, StatusCode}; +use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +const VERSION_HEADER: &str = "x-goog-generation"; +const DEFAULT_CONTENT_TYPE: &str = "application/octet-stream"; +const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-goog-meta-"; + +static VERSION_MATCH: HeaderName = HeaderName::from_static("x-goog-if-generation-match"); + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Error performing list request: {}", source)] + ListRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting list response body: {}", source)] + ListResponseBody { source: HttpError }, + + #[error("Got invalid list response: {}", source)] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[error("Error performing get request {}: {}", path, source)] + GetRequest { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error performing request {}: {}", path, source)] + Request { + source: crate::client::retry::RetryError, + path: String, + }, + + #[error("Error getting put response body: {}", source)] + PutResponseBody { source: HttpError }, + + #[error("Got invalid put request: {}", source)] + InvalidPutRequest { source: quick_xml::se::SeError }, + + #[error("Got invalid put response: {}", source)] + InvalidPutResponse { source: quick_xml::de::DeError }, + + #[error("Unable to extract metadata from headers: {}", source)] + Metadata { + source: crate::client::header::Error, + }, + + #[error("Version required for conditional update")] + MissingVersion, + + #[error("Error performing complete multipart request: {}", source)] + CompleteMultipartRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting complete multipart response body: {}", source)] + CompleteMultipartResponseBody { source: HttpError }, + + #[error("Got invalid multipart response: {}", source)] + InvalidMultipartResponse { source: quick_xml::de::DeError }, + + #[error("Error signing blob: {}", source)] + SignBlobRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Got invalid signing blob response: {}", source)] + InvalidSignBlobResponse { source: HttpError }, + + #[error("Got invalid signing blob signature: {}", source)] + InvalidSignBlobSignature { source: base64::DecodeError }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + 
Error::GetRequest { source, path } | Error::Request { source, path } => { + source.error(STORE, path) + } + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +#[derive(Debug)] +pub(crate) struct GoogleCloudStorageConfig { + pub base_url: String, + + pub credentials: GcpCredentialProvider, + + pub signing_credentials: GcpSigningCredentialProvider, + + pub bucket_name: String, + + pub retry_config: RetryConfig, + + pub client_options: ClientOptions, +} + +impl GoogleCloudStorageConfig { + pub(crate) fn new( + base_url: String, + credentials: GcpCredentialProvider, + signing_credentials: GcpSigningCredentialProvider, + bucket_name: String, + retry_config: RetryConfig, + client_options: ClientOptions, + ) -> Self { + Self { + base_url, + credentials, + signing_credentials, + bucket_name, + retry_config, + client_options, + } + } + + pub(crate) fn path_url(&self, path: &Path) -> String { + format!("{}/{}/{}", self.base_url, self.bucket_name, path) + } +} + +/// A builder for a put request allowing customisation of the headers and query string +pub(crate) struct Request<'a> { + path: &'a Path, + config: &'a GoogleCloudStorageConfig, + payload: Option, + builder: HttpRequestBuilder, + idempotent: bool, +} + +impl Request<'_> { + fn header(self, k: &HeaderName, v: &str) -> Self { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + fn query(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + fn idempotent(mut self, idempotent: bool) -> Self { + self.idempotent = idempotent; + self + } + + fn with_attributes(self, attributes: Attributes) -> Self { + let mut builder = self.builder; + let mut has_content_type = false; + for (k, v) in &attributes { + builder = match k { + Attribute::CacheControl => builder.header(CACHE_CONTROL, v.as_ref()), + Attribute::ContentDisposition => builder.header(CONTENT_DISPOSITION, v.as_ref()), + Attribute::ContentEncoding => builder.header(CONTENT_ENCODING, v.as_ref()), + Attribute::ContentLanguage => builder.header(CONTENT_LANGUAGE, v.as_ref()), + Attribute::ContentType => { + has_content_type = true; + builder.header(CONTENT_TYPE, v.as_ref()) + } + Attribute::Metadata(k_suffix) => builder.header( + &format!("{}{}", USER_DEFINED_METADATA_HEADER_PREFIX, k_suffix), + v.as_ref(), + ), + }; + } + + if !has_content_type { + let value = self.config.client_options.get_content_type(self.path); + builder = builder.header(CONTENT_TYPE, value.unwrap_or(DEFAULT_CONTENT_TYPE)) + } + Self { builder, ..self } + } + + fn with_payload(self, payload: PutPayload) -> Self { + let content_length = payload.content_length(); + Self { + builder: self.builder.header(CONTENT_LENGTH, content_length), + payload: Some(payload), + ..self + } + } + + fn with_extensions(self, extensions: ::http::Extensions) -> Self { + let builder = self.builder.extensions(extensions); + Self { builder, ..self } + } + + async fn send(self) -> Result { + let credential = self.config.credentials.get_credential().await?; + let resp = self + .builder + .bearer_auth(&credential.bearer) + .retryable(&self.config.retry_config) + .idempotent(self.idempotent) + .payload(self.payload) + .send() + .await + .map_err(|source| { + let path = self.path.as_ref().into(); + Error::Request { source, path } + })?; + Ok(resp) + } + + async fn do_put(self) -> Result { + let response = self.send().await?; + Ok(get_put_result(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?) 
+ } +} + +/// Sign Blob Request Body +#[derive(Debug, Serialize)] +struct SignBlobBody { + /// The payload to sign + payload: String, +} + +/// Sign Blob Response +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct SignBlobResponse { + /// The signature for the payload + signed_blob: String, +} + +#[derive(Debug)] +pub(crate) struct GoogleCloudStorageClient { + config: GoogleCloudStorageConfig, + + client: HttpClient, + + bucket_name_encoded: String, + + // TODO: Hook this up in tests + max_list_results: Option, +} + +impl GoogleCloudStorageClient { + pub(crate) fn new(config: GoogleCloudStorageConfig, client: HttpClient) -> Result { + let bucket_name_encoded = + percent_encode(config.bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); + + Ok(Self { + config, + client, + bucket_name_encoded, + max_list_results: None, + }) + } + + pub(crate) fn config(&self) -> &GoogleCloudStorageConfig { + &self.config + } + + async fn get_credential(&self) -> Result> { + self.config.credentials.get_credential().await + } + + /// Create a signature from a string-to-sign using Google Cloud signBlob method. + /// form like: + /// ```plaintext + /// curl -X POST --data-binary @JSON_FILE_NAME \ + /// -H "Authorization: Bearer OAUTH2_TOKEN" \ + /// -H "Content-Type: application/json" \ + /// "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT_EMAIL:signBlob" + /// ``` + /// + /// 'JSON_FILE_NAME' is a file containing the following JSON object: + /// ```plaintext + /// { + /// "payload": "REQUEST_INFORMATION" + /// } + /// ``` + pub(crate) async fn sign_blob( + &self, + string_to_sign: &str, + client_email: &str, + ) -> Result { + let credential = self.get_credential().await?; + let body = SignBlobBody { + payload: BASE64_STANDARD.encode(string_to_sign), + }; + + let url = format!( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}:signBlob", + client_email + ); + + let response = self + .client + .post(&url) + .bearer_auth(&credential.bearer) + .json(&body) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() + .await + .map_err(|source| Error::SignBlobRequest { source })? 
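+            // The response body is JSON carrying a base64-encoded `signedBlob`;
+            // it is decoded and hex-encoded below for use in the signed URL.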
+ .into_body() + .json::() + .await + .map_err(|source| Error::InvalidSignBlobResponse { source })?; + + let signed_blob = BASE64_STANDARD + .decode(response.signed_blob) + .map_err(|source| Error::InvalidSignBlobSignature { source })?; + + Ok(hex_encode(&signed_blob)) + } + + pub(crate) fn object_url(&self, path: &Path) -> String { + let encoded = utf8_percent_encode(path.as_ref(), NON_ALPHANUMERIC); + format!( + "{}/{}/{}", + self.config.base_url, self.bucket_name_encoded, encoded + ) + } + + /// Perform a put request + /// + /// Returns the new ETag + pub(crate) fn request<'a>(&'a self, method: Method, path: &'a Path) -> Request<'a> { + let builder = self.client.request(method, self.object_url(path)); + + Request { + path, + builder, + payload: None, + config: &self.config, + idempotent: false, + } + } + + pub(crate) async fn put( + &self, + path: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let PutOptions { + mode, + // not supported by GCP + tags: _, + attributes, + extensions, + } = opts; + + let builder = self + .request(Method::PUT, path) + .with_payload(payload) + .with_attributes(attributes) + .with_extensions(extensions); + + let builder = match &mode { + PutMode::Overwrite => builder.idempotent(true), + PutMode::Create => builder.header(&VERSION_MATCH, "0"), + PutMode::Update(v) => { + let etag = v.version.as_ref().ok_or(Error::MissingVersion)?; + builder.header(&VERSION_MATCH, etag) + } + }; + + match (mode, builder.do_put().await) { + (PutMode::Create, Err(crate::Error::Precondition { path, source })) => { + Err(crate::Error::AlreadyExists { path, source }) + } + (_, r) => r, + } + } + + /// Perform a put part request + /// + /// Returns the new [`PartId`] + pub(crate) async fn put_part( + &self, + path: &Path, + upload_id: &MultipartId, + part_idx: usize, + data: PutPayload, + ) -> Result { + let query = &[ + ("partNumber", &format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ]; + let result = self + .request(Method::PUT, path) + .with_payload(data) + .query(query) + .idempotent(true) + .do_put() + .await?; + + Ok(PartId { + content_id: result.e_tag.unwrap(), + }) + } + + /// Initiate a multipart upload + pub(crate) async fn multipart_initiate( + &self, + path: &Path, + opts: PutMultipartOpts, + ) -> Result { + let PutMultipartOpts { + // not supported by GCP + tags: _, + attributes, + extensions, + } = opts; + + let response = self + .request(Method::POST, path) + .with_attributes(attributes) + .with_extensions(extensions) + .header(&CONTENT_LENGTH, "0") + .query(&[("uploads", "")]) + .send() + .await?; + + let data = response + .into_body() + .bytes() + .await + .map_err(|source| Error::PutResponseBody { source })?; + + let result: InitiateMultipartUploadResult = + quick_xml::de::from_reader(data.as_ref().reader()) + .map_err(|source| Error::InvalidPutResponse { source })?; + + Ok(result.upload_id) + } + + /// Cleanup unused parts + pub(crate) async fn multipart_cleanup( + &self, + path: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + self.client + .request(Method::DELETE, &url) + .bearer_auth(&credential.bearer) + .header(CONTENT_TYPE, "application/octet-stream") + .header(CONTENT_LENGTH, "0") + .query(&[("uploadId", multipart_id)]) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| { + let path = path.as_ref().into(); + Error::Request { source, path } + })?; + + Ok(()) + } + + pub(crate) async fn multipart_complete( + &self, + 
path: &Path, + multipart_id: &MultipartId, + completed_parts: Vec, + ) -> Result { + if completed_parts.is_empty() { + // GCS doesn't allow empty multipart uploads + let result = self + .request(Method::PUT, path) + .idempotent(true) + .do_put() + .await?; + self.multipart_cleanup(path, multipart_id).await?; + return Ok(result); + } + + let upload_id = multipart_id.clone(); + let url = self.object_url(path); + + let upload_info = CompleteMultipartUpload::from(completed_parts); + let credential = self.get_credential().await?; + + let data = quick_xml::se::to_string(&upload_info) + .map_err(|source| Error::InvalidPutRequest { source })? + // We cannot disable the escaping that transforms "/" to ""e;" :( + // https://github.com/tafia/quick-xml/issues/362 + // https://github.com/tafia/quick-xml/issues/350 + .replace(""", "\""); + + let response = self + .client + .request(Method::POST, &url) + .bearer_auth(&credential.bearer) + .query(&[("uploadId", upload_id)]) + .body(data) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() + .await + .map_err(|source| Error::CompleteMultipartRequest { source })?; + + let version = get_version(response.headers(), VERSION_HEADER) + .map_err(|source| Error::Metadata { source })?; + + let data = response + .into_body() + .bytes() + .await + .map_err(|source| Error::CompleteMultipartResponseBody { source })?; + + let response: CompleteMultipartUploadResult = quick_xml::de::from_reader(data.reader()) + .map_err(|source| Error::InvalidMultipartResponse { source })?; + + Ok(PutResult { + e_tag: Some(response.e_tag), + version, + }) + } + + /// Perform a delete request + pub(crate) async fn delete_request(&self, path: &Path) -> Result<()> { + self.request(Method::DELETE, path).send().await?; + Ok(()) + } + + /// Perform a copy request + pub(crate) async fn copy_request( + &self, + from: &Path, + to: &Path, + if_not_exists: bool, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.object_url(to); + + let from = utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC); + let source = format!("{}/{}", self.bucket_name_encoded, from); + + let mut builder = self + .client + .request(Method::PUT, url) + .header("x-goog-copy-source", source); + + if if_not_exists { + builder = builder.header(&VERSION_MATCH, 0); + } + + builder + .bearer_auth(&credential.bearer) + // Needed if reqwest is compiled with native-tls instead of rustls-tls + // See https://github.com/apache/arrow-rs/pull/3921 + .header(CONTENT_LENGTH, 0) + .retryable(&self.config.retry_config) + .idempotent(!if_not_exists) + .send() + .await + .map_err(|err| match err.status() { + Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists { + source: Box::new(err), + path: to.to_string(), + }, + _ => err.error(STORE, from.to_string()), + })?; + + Ok(()) + } +} + +#[async_trait] +impl GetClient for GoogleCloudStorageClient { + const STORE: &'static str = STORE; + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: true, + last_modified_required: true, + version_header: Some(VERSION_HEADER), + user_defined_metadata_prefix: Some(USER_DEFINED_METADATA_HEADER_PREFIX), + }; + + /// Perform a get request + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut request = self.client.request(method, url); + + if let Some(version) = 
&options.version { + request = request.query(&[("generation", version)]); + } + + if !credential.bearer.is_empty() { + request = request.bearer_auth(&credential.bearer); + } + + let response = request + .with_get_options(options) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| { + let path = path.as_ref().into(); + Error::GetRequest { source, path } + })?; + + Ok(response) + } +} + +#[async_trait] +impl ListClient for Arc { + /// Perform a list request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + page_token: Option<&str>, + offset: Option<&str>, + ) -> Result<(ListResult, Option)> { + let credential = self.get_credential().await?; + let url = format!("{}/{}", self.config.base_url, self.bucket_name_encoded); + + let mut query = Vec::with_capacity(5); + query.push(("list-type", "2")); + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + if let Some(prefix) = &prefix { + query.push(("prefix", prefix)) + } + + if let Some(page_token) = page_token { + query.push(("continuation-token", page_token)) + } + + if let Some(max_results) = &self.max_list_results { + query.push(("max-keys", max_results)) + } + + if let Some(offset) = offset { + query.push(("start-after", offset)) + } + + let response = self + .client + .request(Method::GET, url) + .query(&query) + .bearer_auth(&credential.bearer) + .send_retry(&self.config.retry_config) + .await + .map_err(|source| Error::ListRequest { source })? + .into_body() + .bytes() + .await + .map_err(|source| Error::ListResponseBody { source })?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidListResponse { source })?; + + let token = response.next_continuation_token.take(); + Ok((response.try_into()?, token)) + } +} diff --git a/src/gcp/credential.rs b/src/gcp/credential.rs new file mode 100644 index 0000000..373c2c2 --- /dev/null +++ b/src/gcp/credential.rs @@ -0,0 +1,891 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
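+
+// This module provides the GCP credential sources used by the store: service
+// account keys (used to mint self-signed JWTs), the GCE instance metadata
+// server, gcloud application default credentials and authorized user refresh
+// tokens, together with the V4 signing logic used for signed URLs.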
+ +use super::client::GoogleCloudStorageClient; +use crate::client::retry::RetryExt; +use crate::client::token::TemporaryToken; +use crate::client::{HttpClient, HttpError, TokenProvider}; +use crate::gcp::{GcpSigningCredentialProvider, STORE}; +use crate::util::{hex_digest, hex_encode, STRICT_ENCODE_SET}; +use crate::{RetryConfig, StaticCredentialProvider}; +use async_trait::async_trait; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; +use base64::Engine; +use chrono::{DateTime, Utc}; +use futures::TryFutureExt; +use http::{HeaderMap, Method}; +use itertools::Itertools; +use percent_encoding::utf8_percent_encode; +use ring::signature::RsaKeyPair; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::env; +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::info; +use url::Url; + +pub(crate) const DEFAULT_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; + +pub(crate) const DEFAULT_GCS_BASE_URL: &str = "https://storage.googleapis.com"; + +const DEFAULT_GCS_PLAYLOAD_STRING: &str = "UNSIGNED-PAYLOAD"; +const DEFAULT_GCS_SIGN_BLOB_HOST: &str = "storage.googleapis.com"; + +const DEFAULT_METADATA_HOST: &str = "metadata.google.internal"; +const DEFAULT_METADATA_IP: &str = "169.254.169.254"; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Unable to open service account file from {}: {}", path.display(), source)] + OpenCredentials { + source: std::io::Error, + path: PathBuf, + }, + + #[error("Unable to decode service account file: {}", source)] + DecodeCredentials { source: serde_json::Error }, + + #[error("No RSA key found in pem file")] + MissingKey, + + #[error("Invalid RSA key: {}", source)] + InvalidKey { + #[from] + source: ring::error::KeyRejected, + }, + + #[error("Error signing: {}", source)] + Sign { source: ring::error::Unspecified }, + + #[error("Error encoding jwt payload: {}", source)] + Encode { source: serde_json::Error }, + + #[error("Unsupported key encoding: {}", encoding)] + UnsupportedKey { encoding: String }, + + #[error("Error performing token request: {}", source)] + TokenRequest { + source: crate::client::retry::RetryError, + }, + + #[error("Error getting token response body: {}", source)] + TokenResponseBody { source: HttpError }, +} + +impl From for crate::Error { + fn from(value: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(value), + } + } +} + +/// A Google Cloud Storage Credential for signing +#[derive(Debug)] +pub struct GcpSigningCredential { + /// The email of the service account + pub email: String, + + /// An optional RSA private key + /// + /// If provided this will be used to sign the URL, otherwise a call will be made to + /// [`iam.serviceAccounts.signBlob`]. This allows supporting credential sources + /// that don't expose the service account private key, e.g. [IMDS]. 
+ /// + /// [IMDS]: https://cloud.google.com/docs/authentication/get-id-token#metadata-server + /// [`iam.serviceAccounts.signBlob`]: https://cloud.google.com/storage/docs/authentication/creating-signatures + pub private_key: Option, +} + +/// A private RSA key for a service account +#[derive(Debug)] +pub struct ServiceAccountKey(RsaKeyPair); + +impl ServiceAccountKey { + /// Parses a pem-encoded RSA key + pub fn from_pem(encoded: &[u8]) -> Result { + use rustls_pemfile::Item; + use std::io::Cursor; + + let mut cursor = Cursor::new(encoded); + let mut reader = BufReader::new(&mut cursor); + + // Reading from string is infallible + match rustls_pemfile::read_one(&mut reader).unwrap() { + Some(Item::Pkcs8Key(key)) => Self::from_pkcs8(key.secret_pkcs8_der()), + Some(Item::Pkcs1Key(key)) => Self::from_der(key.secret_pkcs1_der()), + _ => Err(Error::MissingKey), + } + } + + /// Parses an unencrypted PKCS#8-encoded RSA private key. + pub fn from_pkcs8(key: &[u8]) -> Result { + Ok(Self(RsaKeyPair::from_pkcs8(key)?)) + } + + /// Parses an unencrypted PKCS#8-encoded RSA private key. + pub fn from_der(key: &[u8]) -> Result { + Ok(Self(RsaKeyPair::from_der(key)?)) + } + + fn sign(&self, string_to_sign: &str) -> Result { + let mut signature = vec![0; self.0.public().modulus_len()]; + self.0 + .sign( + &ring::signature::RSA_PKCS1_SHA256, + &ring::rand::SystemRandom::new(), + string_to_sign.as_bytes(), + &mut signature, + ) + .map_err(|source| Error::Sign { source })?; + + Ok(hex_encode(&signature)) + } +} + +/// A Google Cloud Storage Credential +#[derive(Debug, Eq, PartialEq)] +pub struct GcpCredential { + /// An HTTP bearer token + pub bearer: String, +} + +pub(crate) type Result = std::result::Result; + +#[derive(Debug, Default, serde::Serialize)] +pub(crate) struct JwtHeader<'a> { + /// The type of JWS: it can only be "JWT" here + /// + /// Defined in [RFC7515#4.1.9](https://tools.ietf.org/html/rfc7515#section-4.1.9). + #[serde(skip_serializing_if = "Option::is_none")] + pub typ: Option<&'a str>, + /// The algorithm used + /// + /// Defined in [RFC7515#4.1.1](https://tools.ietf.org/html/rfc7515#section-4.1.1). + pub alg: &'a str, + /// Content type + /// + /// Defined in [RFC7519#5.2](https://tools.ietf.org/html/rfc7519#section-5.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub cty: Option<&'a str>, + /// JSON Key URL + /// + /// Defined in [RFC7515#4.1.2](https://tools.ietf.org/html/rfc7515#section-4.1.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub jku: Option<&'a str>, + /// Key ID + /// + /// Defined in [RFC7515#4.1.4](https://tools.ietf.org/html/rfc7515#section-4.1.4). + #[serde(skip_serializing_if = "Option::is_none")] + pub kid: Option<&'a str>, + /// X.509 URL + /// + /// Defined in [RFC7515#4.1.5](https://tools.ietf.org/html/rfc7515#section-4.1.5). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5u: Option<&'a str>, + /// X.509 certificate thumbprint + /// + /// Defined in [RFC7515#4.1.7](https://tools.ietf.org/html/rfc7515#section-4.1.7). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5t: Option<&'a str>, +} + +#[derive(serde::Serialize)] +struct TokenClaims<'a> { + iss: &'a str, + sub: &'a str, + scope: &'a str, + exp: u64, + iat: u64, +} + +#[derive(serde::Deserialize, Debug)] +struct TokenResponse { + access_token: String, + expires_in: u64, +} + +/// Self-signed JWT (JSON Web Token). 
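+/// Rather than exchanging the service account key for an OAuth access token,
+/// the JWT is signed locally with RS256 and presented directly as the bearer token.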
+/// +/// # References +/// - +#[derive(Debug)] +pub(crate) struct SelfSignedJwt { + issuer: String, + scope: String, + private_key: ServiceAccountKey, + key_id: String, +} + +impl SelfSignedJwt { + /// Create a new [`SelfSignedJwt`] + pub(crate) fn new( + key_id: String, + issuer: String, + private_key: ServiceAccountKey, + scope: String, + ) -> Result { + Ok(Self { + issuer, + scope, + private_key, + key_id, + }) + } +} + +#[async_trait] +impl TokenProvider for SelfSignedJwt { + type Credential = GcpCredential; + + /// Fetch a fresh token + async fn fetch_token( + &self, + _client: &HttpClient, + _retry: &RetryConfig, + ) -> crate::Result>> { + let now = seconds_since_epoch(); + let exp = now + 3600; + + let claims = TokenClaims { + iss: &self.issuer, + sub: &self.issuer, + scope: &self.scope, + iat: now, + exp, + }; + + let jwt_header = b64_encode_obj(&JwtHeader { + alg: "RS256", + typ: Some("JWT"), + kid: Some(&self.key_id), + ..Default::default() + })?; + + let claim_str = b64_encode_obj(&claims)?; + let message = [jwt_header.as_ref(), claim_str.as_ref()].join("."); + let mut sig_bytes = vec![0; self.private_key.0.public().modulus_len()]; + self.private_key + .0 + .sign( + &ring::signature::RSA_PKCS1_SHA256, + &ring::rand::SystemRandom::new(), + message.as_bytes(), + &mut sig_bytes, + ) + .map_err(|source| Error::Sign { source })?; + + let signature = BASE64_URL_SAFE_NO_PAD.encode(sig_bytes); + let bearer = [message, signature].join("."); + + Ok(TemporaryToken { + token: Arc::new(GcpCredential { bearer }), + expiry: Some(Instant::now() + Duration::from_secs(3600)), + }) + } +} + +fn read_credentials_file(service_account_path: impl AsRef) -> Result +where + T: serde::de::DeserializeOwned, +{ + let file = File::open(&service_account_path).map_err(|source| { + let path = service_account_path.as_ref().to_owned(); + Error::OpenCredentials { source, path } + })?; + let reader = BufReader::new(file); + serde_json::from_reader(reader).map_err(|source| Error::DecodeCredentials { source }) +} + +/// A deserialized `service-account-********.json`-file. +#[derive(serde::Deserialize, Debug, Clone)] +pub(crate) struct ServiceAccountCredentials { + /// The private key in RSA format. + pub private_key: String, + + /// The private key ID + pub private_key_id: String, + + /// The email address associated with the service account. + pub client_email: String, + + /// Base URL for GCS + #[serde(default)] + pub gcs_base_url: Option, + + /// Disable oauth and use empty tokens. + #[serde(default)] + pub disable_oauth: bool, +} + +impl ServiceAccountCredentials { + /// Create a new [`ServiceAccountCredentials`] from a file. + pub(crate) fn from_file>(path: P) -> Result { + read_credentials_file(path) + } + + /// Create a new [`ServiceAccountCredentials`] from a string. + pub(crate) fn from_key(key: &str) -> Result { + serde_json::from_str(key).map_err(|source| Error::DecodeCredentials { source }) + } + + /// Create a [`SelfSignedJwt`] from this credentials struct. + /// + /// We use a scope of [`DEFAULT_SCOPE`] as opposed to an audience + /// as GCS appears to not support audience + /// + /// # References + /// - + /// - + pub(crate) fn token_provider(self) -> crate::Result { + Ok(SelfSignedJwt::new( + self.private_key_id, + self.client_email, + ServiceAccountKey::from_pem(self.private_key.as_bytes())?, + DEFAULT_SCOPE.to_string(), + )?) 
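+        // Tokens minted this way are valid for one hour (see `fetch_token`).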
+ } + + pub(crate) fn signing_credentials(self) -> crate::Result { + Ok(Arc::new(StaticCredentialProvider::new( + GcpSigningCredential { + email: self.client_email, + private_key: Some(ServiceAccountKey::from_pem(self.private_key.as_bytes())?), + }, + ))) + } +} + +/// Returns the number of seconds since unix epoch +fn seconds_since_epoch() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() +} + +fn b64_encode_obj(obj: &T) -> Result { + let string = serde_json::to_string(obj).map_err(|source| Error::Encode { source })?; + Ok(BASE64_URL_SAFE_NO_PAD.encode(string)) +} + +/// A provider that uses the Google Cloud Platform metadata server to fetch a token. +/// +/// +#[derive(Debug, Default)] +pub(crate) struct InstanceCredentialProvider {} + +/// Make a request to the metadata server to fetch a token, using a a given hostname. +async fn make_metadata_request( + client: &HttpClient, + hostname: &str, + retry: &RetryConfig, +) -> crate::Result { + let url = + format!("http://{hostname}/computeMetadata/v1/instance/service-accounts/default/token"); + let response: TokenResponse = client + .get(url) + .header("Metadata-Flavor", "Google") + .query(&[("audience", "https://www.googleapis.com/oauth2/v4/token")]) + .send_retry(retry) + .await + .map_err(|source| Error::TokenRequest { source })? + .into_body() + .json() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + Ok(response) +} + +#[async_trait] +impl TokenProvider for InstanceCredentialProvider { + type Credential = GcpCredential; + + /// Fetch a token from the metadata server. + /// Since the connection is local we need to enable http access and don't actually use the client object passed in. + /// Respects the `GCE_METADATA_HOST`, `GCE_METADATA_ROOT`, and `GCE_METADATA_IP` + /// environment variables. + /// + /// References: + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let metadata_host = if let Ok(host) = env::var("GCE_METADATA_HOST") { + host + } else if let Ok(host) = env::var("GCE_METADATA_ROOT") { + host + } else { + DEFAULT_METADATA_HOST.to_string() + }; + let metadata_ip = if let Ok(ip) = env::var("GCE_METADATA_IP") { + ip + } else { + DEFAULT_METADATA_IP.to_string() + }; + + info!("fetching token from metadata server"); + let response = make_metadata_request(client, &metadata_host, retry) + .or_else(|_| make_metadata_request(client, &metadata_ip, retry)) + .await?; + + let token = TemporaryToken { + token: Arc::new(GcpCredential { + bearer: response.access_token, + }), + expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + }; + Ok(token) + } +} + +/// Make a request to the metadata server to fetch the client email, using a given hostname. +async fn make_metadata_request_for_email( + client: &HttpClient, + hostname: &str, + retry: &RetryConfig, +) -> crate::Result { + let url = + format!("http://{hostname}/computeMetadata/v1/instance/service-accounts/default/email",); + let response = client + .get(url) + .header("Metadata-Flavor", "Google") + .send_retry(retry) + .await + .map_err(|source| Error::TokenRequest { source })? + .into_body() + .text() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + Ok(response) +} + +/// A provider that uses the Google Cloud Platform metadata server to fetch a email for signing. 
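+/// The fetched email identifies the service account whose key is used server-side
+/// via `signBlob` when generating signed URLs (no private key is returned).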
+/// +/// +#[derive(Debug, Default)] +pub(crate) struct InstanceSigningCredentialProvider {} + +#[async_trait] +impl TokenProvider for InstanceSigningCredentialProvider { + type Credential = GcpSigningCredential; + + /// Fetch a token from the metadata server. + /// Since the connection is local we need to enable http access and don't actually use the client object passed in. + /// Respects the `GCE_METADATA_HOST`, `GCE_METADATA_ROOT`, and `GCE_METADATA_IP` + /// environment variables. + /// + /// References: + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let metadata_host = if let Ok(host) = env::var("GCE_METADATA_HOST") { + host + } else if let Ok(host) = env::var("GCE_METADATA_ROOT") { + host + } else { + DEFAULT_METADATA_HOST.to_string() + }; + + let metadata_ip = if let Ok(ip) = env::var("GCE_METADATA_IP") { + ip + } else { + DEFAULT_METADATA_IP.to_string() + }; + + info!("fetching token from metadata server"); + + let email = make_metadata_request_for_email(client, &metadata_host, retry) + .or_else(|_| make_metadata_request_for_email(client, &metadata_ip, retry)) + .await?; + + let token = TemporaryToken { + token: Arc::new(GcpSigningCredential { + email, + private_key: None, + }), + expiry: None, + }; + Ok(token) + } +} + +/// A deserialized `application_default_credentials.json`-file. +/// +/// # References +/// - +/// - +#[derive(serde::Deserialize, Clone)] +#[serde(tag = "type")] +pub(crate) enum ApplicationDefaultCredentials { + /// Service Account. + /// + /// # References + /// - + #[serde(rename = "service_account")] + ServiceAccount(ServiceAccountCredentials), + /// Authorized user via "gcloud CLI Integration". + /// + /// # References + /// - + #[serde(rename = "authorized_user")] + AuthorizedUser(AuthorizedUserCredentials), +} + +impl ApplicationDefaultCredentials { + const CREDENTIALS_PATH: &'static str = if cfg!(windows) { + "gcloud/application_default_credentials.json" + } else { + ".config/gcloud/application_default_credentials.json" + }; + + // Create a new application default credential in the following situations: + // 1. a file is passed in and the type matches. + // 2. without argument if the well-known configuration file is present. + pub(crate) fn read(path: Option<&str>) -> Result, Error> { + if let Some(path) = path { + return read_credentials_file::(path).map(Some); + } + + let home_var = if cfg!(windows) { "APPDATA" } else { "HOME" }; + if let Some(home) = env::var_os(home_var) { + let path = Path::new(&home).join(Self::CREDENTIALS_PATH); + + // It's expected for this file to not exist unless it has been explicitly configured by the user. 
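+            // e.g. `$HOME/.config/gcloud/application_default_credentials.json` on
+            // Linux/macOS, or `%APPDATA%\gcloud\application_default_credentials.json`
+            // on Windows (see `CREDENTIALS_PATH` above).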
+ if path.exists() { + return read_credentials_file::(path).map(Some); + } + } + Ok(None) + } +} + +const DEFAULT_TOKEN_GCP_URI: &str = "https://accounts.google.com/o/oauth2/token"; + +/// +#[derive(Debug, Deserialize, Clone)] +pub(crate) struct AuthorizedUserCredentials { + client_id: String, + client_secret: String, + refresh_token: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct AuthorizedUserSigningCredentials { + credential: AuthorizedUserCredentials, +} + +/// +#[derive(Debug, Deserialize)] +struct EmailResponse { + email: String, +} + +impl AuthorizedUserSigningCredentials { + pub(crate) fn from(credential: AuthorizedUserCredentials) -> crate::Result { + Ok(Self { credential }) + } + + async fn client_email( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result { + let response = client + .get("https://oauth2.googleapis.com/tokeninfo") + .query(&[("access_token", &self.credential.refresh_token)]) + .send_retry(retry) + .await + .map_err(|source| Error::TokenRequest { source })? + .into_body() + .json::() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + + Ok(response.email) + } +} + +#[async_trait] +impl TokenProvider for AuthorizedUserSigningCredentials { + type Credential = GcpSigningCredential; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let email = self.client_email(client, retry).await?; + + Ok(TemporaryToken { + token: Arc::new(GcpSigningCredential { + email, + private_key: None, + }), + expiry: None, + }) + } +} + +#[async_trait] +impl TokenProvider for AuthorizedUserCredentials { + type Credential = GcpCredential; + + async fn fetch_token( + &self, + client: &HttpClient, + retry: &RetryConfig, + ) -> crate::Result>> { + let response = client + .post(DEFAULT_TOKEN_GCP_URI) + .form([ + ("grant_type", "refresh_token"), + ("client_id", &self.client_id), + ("client_secret", &self.client_secret), + ("refresh_token", &self.refresh_token), + ]) + .retryable(retry) + .idempotent(true) + .send() + .await + .map_err(|source| Error::TokenRequest { source })? 
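+            // The OAuth `refresh_token` grant above returns JSON with `access_token`
+            // and `expires_in`, which become the bearer token and expiry below.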
+ .into_body() + .json::() + .await + .map_err(|source| Error::TokenResponseBody { source })?; + + Ok(TemporaryToken { + token: Arc::new(GcpCredential { + bearer: response.access_token, + }), + expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + }) + } +} + +/// Trim whitespace from header values +fn trim_header_value(value: &str) -> String { + let mut ret = value.to_string(); + ret.retain(|c| !c.is_whitespace()); + ret +} + +/// A Google Cloud Storage Authorizer for generating signed URL using [Google SigV4] +/// +/// [Google SigV4]: https://cloud.google.com/storage/docs/access-control/signed-urls +#[derive(Debug)] +pub(crate) struct GCSAuthorizer { + date: Option>, + credential: Arc, +} + +impl GCSAuthorizer { + /// Create a new [`GCSAuthorizer`] + pub(crate) fn new(credential: Arc) -> Self { + Self { + date: None, + credential, + } + } + + pub(crate) async fn sign( + &self, + method: Method, + url: &mut Url, + expires_in: Duration, + client: &GoogleCloudStorageClient, + ) -> crate::Result<()> { + let email = &self.credential.email; + let date = self.date.unwrap_or_else(Utc::now); + let scope = self.scope(date); + let credential_with_scope = format!("{}/{}", email, scope); + + let mut headers = HeaderMap::new(); + headers.insert("host", DEFAULT_GCS_SIGN_BLOB_HOST.parse().unwrap()); + + let (_, signed_headers) = Self::canonicalize_headers(&headers); + + url.query_pairs_mut() + .append_pair("X-Goog-Algorithm", "GOOG4-RSA-SHA256") + .append_pair("X-Goog-Credential", &credential_with_scope) + .append_pair("X-Goog-Date", &date.format("%Y%m%dT%H%M%SZ").to_string()) + .append_pair("X-Goog-Expires", &expires_in.as_secs().to_string()) + .append_pair("X-Goog-SignedHeaders", &signed_headers); + + let string_to_sign = self.string_to_sign(date, &method, url, &headers); + let signature = match &self.credential.private_key { + Some(key) => key.sign(&string_to_sign)?, + None => client.sign_blob(&string_to_sign, email).await?, + }; + + url.query_pairs_mut() + .append_pair("X-Goog-Signature", &signature); + Ok(()) + } + + /// Get scope for the request + /// + /// + fn scope(&self, date: DateTime) -> String { + format!("{}/auto/storage/goog4_request", date.format("%Y%m%d"),) + } + + /// Canonicalizes query parameters into the GCP canonical form + /// form like: + ///```plaintext + ///HTTP_VERB + ///PATH_TO_RESOURCE + ///CANONICAL_QUERY_STRING + ///CANONICAL_HEADERS + /// + ///SIGNED_HEADERS + ///PAYLOAD + ///``` + /// + /// + fn canonicalize_request(url: &Url, method: &Method, headers: &HeaderMap) -> String { + let verb = method.as_str(); + let path = url.path(); + let query = Self::canonicalize_query(url); + let (canonical_headers, signed_headers) = Self::canonicalize_headers(headers); + + format!( + "{}\n{}\n{}\n{}\n\n{}\n{}", + verb, path, query, canonical_headers, signed_headers, DEFAULT_GCS_PLAYLOAD_STRING + ) + } + + /// Canonicalizes query parameters into the GCP canonical form + /// form like `max-keys=2&prefix=object` + /// + /// + fn canonicalize_query(url: &Url) -> String { + url.query_pairs() + .sorted_unstable_by(|a, b| a.0.cmp(&b.0)) + .map(|(k, v)| { + format!( + "{}={}", + utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET), + utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET) + ) + }) + .join("&") + } + + /// Canonicalizes header into the GCP canonical form + /// + /// + fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) { + //FIXME add error handling for invalid header values + let mut headers = BTreeMap::>::new(); + for (k, v) in header_map { + 
headers + .entry(k.as_str().to_lowercase()) + .or_default() + .push(std::str::from_utf8(v.as_bytes()).unwrap()); + } + + let canonicalize_headers = headers + .iter() + .map(|(k, v)| { + format!( + "{}:{}", + k.trim(), + v.iter().map(|v| trim_header_value(v)).join(",") + ) + }) + .join("\n"); + + let signed_headers = headers.keys().join(";"); + + (canonicalize_headers, signed_headers) + } + + ///construct the string to sign + ///form like: + ///```plaintext + ///SIGNING_ALGORITHM + ///ACTIVE_DATETIME + ///CREDENTIAL_SCOPE + ///HASHED_CANONICAL_REQUEST + ///``` + ///`ACTIVE_DATETIME` format:`YYYYMMDD'T'HHMMSS'Z'` + /// + pub(crate) fn string_to_sign( + &self, + date: DateTime, + request_method: &Method, + url: &Url, + headers: &HeaderMap, + ) -> String { + let canonical_request = Self::canonicalize_request(url, request_method, headers); + let hashed_canonical_req = hex_digest(canonical_request.as_bytes()); + let scope = self.scope(date); + + format!( + "{}\n{}\n{}\n{}", + "GOOG4-RSA-SHA256", + date.format("%Y%m%dT%H%M%SZ"), + scope, + hashed_canonical_req + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_canonicalize_headers() { + let mut input_header = HeaderMap::new(); + input_header.insert("content-type", "text/plain".parse().unwrap()); + input_header.insert("host", "storage.googleapis.com".parse().unwrap()); + input_header.insert("x-goog-meta-reviewer", "jane".parse().unwrap()); + input_header.append("x-goog-meta-reviewer", "john".parse().unwrap()); + assert_eq!( + GCSAuthorizer::canonicalize_headers(&input_header), + ( + "content-type:text/plain +host:storage.googleapis.com +x-goog-meta-reviewer:jane,john" + .into(), + "content-type;host;x-goog-meta-reviewer".to_string() + ) + ); + } + + #[test] + fn test_canonicalize_query() { + let mut url = Url::parse("https://storage.googleapis.com/bucket/object").unwrap(); + url.query_pairs_mut() + .append_pair("max-keys", "2") + .append_pair("prefix", "object"); + assert_eq!( + GCSAuthorizer::canonicalize_query(&url), + "max-keys=2&prefix=object".to_string() + ); + } +} diff --git a/src/gcp/mod.rs b/src/gcp/mod.rs new file mode 100644 index 0000000..5f8c67d --- /dev/null +++ b/src/gcp/mod.rs @@ -0,0 +1,422 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for Google Cloud Storage +//! +//! ## Multipart uploads +//! +//! [Multipart uploads](https://cloud.google.com/storage/docs/multipart-uploads) +//! can be initiated with the [ObjectStore::put_multipart] method. If neither +//! [`MultipartUpload::complete`] nor [`MultipartUpload::abort`] is invoked, you may +//! have parts uploaded to GCS but not used, that you will be charged for. It is recommended +//! 
you configure a [lifecycle rule] to abort incomplete multipart uploads after a certain +//! period of time to avoid being charged for storing partial uploads. +//! +//! ## Using HTTP/2 +//! +//! Google Cloud Storage supports both HTTP/2 and HTTP/1. HTTP/1 is used by default +//! because it allows much higher throughput in our benchmarks (see +//! [#5194](https://github.com/apache/arrow-rs/issues/5194)). HTTP/2 can be +//! enabled by setting [crate::ClientConfigKey::Http1Only] to false. +//! +//! [lifecycle rule]: https://cloud.google.com/storage/docs/lifecycle#abort-mpu +use std::sync::Arc; +use std::time::Duration; + +use crate::client::CredentialProvider; +use crate::gcp::credential::GCSAuthorizer; +use crate::signer::Signer; +use crate::{ + multipart::PartId, path::Path, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, + ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, + UploadPart, +}; +use async_trait::async_trait; +use client::GoogleCloudStorageClient; +use futures::stream::BoxStream; +use http::Method; +use url::Url; + +use crate::client::get::GetClientExt; +use crate::client::list::ListClientExt; +use crate::client::parts::Parts; +use crate::multipart::MultipartStore; +pub use builder::{GoogleCloudStorageBuilder, GoogleConfigKey}; +pub use credential::{GcpCredential, GcpSigningCredential, ServiceAccountKey}; + +mod builder; +mod client; +mod credential; + +const STORE: &str = "GCS"; + +/// [`CredentialProvider`] for [`GoogleCloudStorage`] +pub type GcpCredentialProvider = Arc>; + +/// [`GcpSigningCredential`] for [`GoogleCloudStorage`] +pub type GcpSigningCredentialProvider = + Arc>; + +/// Interface for [Google Cloud Storage](https://cloud.google.com/storage/). +#[derive(Debug, Clone)] +pub struct GoogleCloudStorage { + client: Arc, +} + +impl std::fmt::Display for GoogleCloudStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "GoogleCloudStorage({})", + self.client.config().bucket_name + ) + } +} + +impl GoogleCloudStorage { + /// Returns the [`GcpCredentialProvider`] used by [`GoogleCloudStorage`] + pub fn credentials(&self) -> &GcpCredentialProvider { + &self.client.config().credentials + } + + /// Returns the [`GcpSigningCredentialProvider`] used by [`GoogleCloudStorage`] + pub fn signing_credentials(&self) -> &GcpSigningCredentialProvider { + &self.client.config().signing_credentials + } +} + +#[derive(Debug)] +struct GCSMultipartUpload { + state: Arc, + part_idx: usize, +} + +#[derive(Debug)] +struct UploadState { + client: Arc, + path: Path, + multipart_id: MultipartId, + parts: Parts, +} + +#[async_trait] +impl MultipartUpload for GCSMultipartUpload { + fn put_part(&mut self, payload: PutPayload) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state + .client + .put_part(&state.path, &state.multipart_id, idx, payload) + .await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .multipart_complete(&self.state.path, &self.state.multipart_id, parts) + .await + } + + async fn abort(&mut self) -> Result<()> { + self.state + .client + .multipart_cleanup(&self.state.path, &self.state.multipart_id) + .await + } +} + +#[async_trait] +impl ObjectStore for GoogleCloudStorage { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) 
-> Result { + self.client.put(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + let upload_id = self.client.multipart_initiate(location, opts).await?; + + Ok(Box::new(GCSMultipartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + path: location.clone(), + multipart_id: upload_id.clone(), + parts: Default::default(), + }), + })) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.client.get_opts(location, options).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.client.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + self.client.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.client.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, false).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, true).await + } +} + +#[async_trait] +impl MultipartStore for GoogleCloudStorage { + async fn create_multipart(&self, path: &Path) -> Result { + self.client + .multipart_initiate(path, PutMultipartOpts::default()) + .await + } + + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + payload: PutPayload, + ) -> Result { + self.client.put_part(path, id, part_idx, payload).await + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result { + self.client.multipart_complete(path, id, parts).await + } + + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> { + self.client.multipart_cleanup(path, id).await + } +} + +#[async_trait] +impl Signer for GoogleCloudStorage { + async fn signed_url(&self, method: Method, path: &Path, expires_in: Duration) -> Result { + if expires_in.as_secs() > 604800 { + return Err(crate::Error::Generic { + store: STORE, + source: "Expiration Time can't be longer than 604800 seconds (7 days).".into(), + }); + } + + let config = self.client.config(); + let path_url = config.path_url(path); + let mut url = Url::parse(&path_url).map_err(|e| crate::Error::Generic { + store: STORE, + source: format!("Unable to parse url {path_url}: {e}").into(), + })?; + + let signing_credentials = self.signing_credentials().get_credential().await?; + let authorizer = GCSAuthorizer::new(signing_credentials); + + authorizer + .sign(method, &mut url, expires_in, &self.client) + .await?; + + Ok(url) + } +} + +#[cfg(test)] +mod test { + + use credential::DEFAULT_GCS_BASE_URL; + + use crate::integration::*; + use crate::tests::*; + + use super::*; + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn gcs_test() { + maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + if integration.client.config().base_url == DEFAULT_GCS_BASE_URL { + // Fake GCS server doesn't currently honor ifGenerationMatch + // 
https://github.com/fsouza/fake-gcs-server/issues/994 + copy_if_not_exists(&integration).await; + // Fake GCS server does not yet implement XML Multipart uploads + // https://github.com/fsouza/fake-gcs-server/issues/852 + stream_get(&integration).await; + multipart(&integration, &integration).await; + multipart_race_condition(&integration, true).await; + multipart_out_of_order(&integration).await; + // Fake GCS server doesn't currently honor preconditions + get_opts(&integration).await; + put_opts(&integration, true).await; + // Fake GCS server doesn't currently support attributes + put_get_attributes(&integration).await; + } + } + + #[tokio::test] + #[ignore] + async fn gcs_test_sign() { + maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); + + let client = reqwest::Client::new(); + + let path = Path::from("test_sign"); + let url = integration + .signed_url(Method::PUT, &path, Duration::from_secs(3600)) + .await + .unwrap(); + println!("PUT {url}"); + + let resp = client.put(url).body("data").send().await.unwrap(); + resp.error_for_status().unwrap(); + + let url = integration + .signed_url(Method::GET, &path, Duration::from_secs(3600)) + .await + .unwrap(); + println!("GET {url}"); + + let resp = client.get(url).send().await.unwrap(); + let resp = resp.error_for_status().unwrap(); + let data = resp.bytes().await.unwrap(); + assert_eq!(data.as_ref(), b"data"); + } + + #[tokio::test] + async fn gcs_test_get_nonexistent_location() { + maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.get(&location).await.unwrap_err(); + + assert!( + matches!(err, crate::Error::NotFound { .. }), + "unexpected error type: {err}" + ); + } + + #[tokio::test] + async fn gcs_test_get_nonexistent_bucket() { + maybe_skip_integration!(); + let config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + + assert!( + matches!(err, crate::Error::NotFound { .. }), + "unexpected error type: {err}" + ); + } + + #[tokio::test] + async fn gcs_test_delete_nonexistent_location() { + maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!( + matches!(err, crate::Error::NotFound { .. }), + "unexpected error type: {err}" + ); + } + + #[tokio::test] + async fn gcs_test_delete_nonexistent_bucket() { + maybe_skip_integration!(); + let config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!( + matches!(err, crate::Error::NotFound { .. 
}), + "unexpected error type: {err}" + ); + } + + #[tokio::test] + async fn gcs_test_put_nonexistent_bucket() { + maybe_skip_integration!(); + let config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + let data = PutPayload::from("arbitrary data"); + + let err = integration + .put(&location, data) + .await + .unwrap_err() + .to_string(); + assert!( + err.contains("Server returned non-2xx status code: 404 Not Found"), + "{}", + err + ) + } +} diff --git a/src/http/client.rs b/src/http/client.rs new file mode 100644 index 0000000..652d326 --- /dev/null +++ b/src/http/client.rs @@ -0,0 +1,478 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::get::GetClient; +use crate::client::header::HeaderConfig; +use crate::client::retry::{self, RetryConfig, RetryExt}; +use crate::client::{GetOptionsExt, HttpClient, HttpError, HttpResponse}; +use crate::path::{Path, DELIMITER}; +use crate::util::deserialize_rfc1123; +use crate::{Attribute, Attributes, ClientOptions, GetOptions, ObjectMeta, PutPayload, Result}; +use async_trait::async_trait; +use bytes::Buf; +use chrono::{DateTime, Utc}; +use http::header::{ + CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, + CONTENT_TYPE, +}; +use percent_encoding::percent_decode_str; +use reqwest::{Method, StatusCode}; +use serde::Deserialize; +use url::Url; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Request error: {}", source)] + Request { source: retry::RetryError }, + + #[error("Request error: {}", source)] + Reqwest { source: HttpError }, + + #[error("Range request not supported by {}", href)] + RangeNotSupported { href: String }, + + #[error("Error decoding PROPFIND response: {}", source)] + InvalidPropFind { source: quick_xml::de::DeError }, + + #[error("Missing content size for {}", href)] + MissingSize { href: String }, + + #[error("Error getting properties of \"{}\" got \"{}\"", href, status)] + PropStatus { href: String, status: String }, + + #[error("Failed to parse href \"{}\": {}", href, source)] + InvalidHref { + href: String, + source: url::ParseError, + }, + + #[error("Path \"{}\" contained non-unicode characters: {}", path, source)] + NonUnicode { + path: String, + source: std::str::Utf8Error, + }, + + #[error("Encountered invalid path \"{}\": {}", path, source)] + InvalidPath { + path: String, + source: crate::path::Error, + }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + Self::Generic { + store: "HTTP", + source: Box::new(err), + } + } +} + +/// Internal client for HttpStore +#[derive(Debug)] +pub(crate) struct Client { + url: Url, + client: HttpClient, + retry_config: 
RetryConfig, + client_options: ClientOptions, +} + +impl Client { + pub(crate) fn new( + url: Url, + client: HttpClient, + client_options: ClientOptions, + retry_config: RetryConfig, + ) -> Self { + Self { + url, + retry_config, + client_options, + client, + } + } + + pub(crate) fn base_url(&self) -> &Url { + &self.url + } + + fn path_url(&self, location: &Path) -> String { + let mut url = self.url.clone(); + url.path_segments_mut().unwrap().extend(location.parts()); + url.to_string() + } + + /// Create a directory with `path` using MKCOL + async fn make_directory(&self, path: &str) -> Result<(), Error> { + let method = Method::from_bytes(b"MKCOL").unwrap(); + let mut url = self.url.clone(); + url.path_segments_mut() + .unwrap() + .extend(path.split(DELIMITER)); + + self.client + .request(method, String::from(url)) + .send_retry(&self.retry_config) + .await + .map_err(|source| Error::Request { source })?; + + Ok(()) + } + + /// Recursively create parent directories + async fn create_parent_directories(&self, location: &Path) -> Result<()> { + let mut stack = vec![]; + + // Walk backwards until a request succeeds + let mut last_prefix = location.as_ref(); + while let Some((prefix, _)) = last_prefix.rsplit_once(DELIMITER) { + last_prefix = prefix; + + match self.make_directory(prefix).await { + Ok(_) => break, + Err(Error::Request { source }) + if matches!(source.status(), Some(StatusCode::CONFLICT)) => + { + // Need to create parent + stack.push(prefix) + } + Err(e) => return Err(e.into()), + } + } + + // Retry the failed requests, which should now succeed + for prefix in stack.into_iter().rev() { + self.make_directory(prefix).await?; + } + + Ok(()) + } + + pub(crate) async fn put( + &self, + location: &Path, + payload: PutPayload, + attributes: Attributes, + ) -> Result { + let mut retry = false; + loop { + let url = self.path_url(location); + let mut builder = self.client.put(url); + + let mut has_content_type = false; + for (k, v) in &attributes { + builder = match k { + Attribute::CacheControl => builder.header(CACHE_CONTROL, v.as_ref()), + Attribute::ContentDisposition => { + builder.header(CONTENT_DISPOSITION, v.as_ref()) + } + Attribute::ContentEncoding => builder.header(CONTENT_ENCODING, v.as_ref()), + Attribute::ContentLanguage => builder.header(CONTENT_LANGUAGE, v.as_ref()), + Attribute::ContentType => { + has_content_type = true; + builder.header(CONTENT_TYPE, v.as_ref()) + } + // Ignore metadata attributes + Attribute::Metadata(_) => builder, + }; + } + + if !has_content_type { + if let Some(value) = self.client_options.get_content_type(location) { + builder = builder.header(CONTENT_TYPE, value); + } + } + + let resp = builder + .header(CONTENT_LENGTH, payload.content_length()) + .retryable(&self.retry_config) + .idempotent(true) + .payload(Some(payload.clone())) + .send() + .await; + + match resp { + Ok(response) => return Ok(response), + Err(source) => match source.status() { + // Some implementations return 404 instead of 409 + Some(StatusCode::CONFLICT | StatusCode::NOT_FOUND) if !retry => { + retry = true; + self.create_parent_directories(location).await? 
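+                        // Missing parent collections are created with MKCOL and the
+                        // PUT is retried once via the surrounding loop.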
+ } + _ => return Err(Error::Request { source }.into()), + }, + } + } + } + + pub(crate) async fn list(&self, location: Option<&Path>, depth: &str) -> Result { + let url = location + .map(|path| self.path_url(path)) + .unwrap_or_else(|| self.url.to_string()); + + let method = Method::from_bytes(b"PROPFIND").unwrap(); + let result = self + .client + .request(method, url) + .header("Depth", depth) + .retryable(&self.retry_config) + .idempotent(true) + .send() + .await; + + let response = match result { + Ok(result) => result + .into_body() + .bytes() + .await + .map_err(|source| Error::Reqwest { source })?, + Err(e) if matches!(e.status(), Some(StatusCode::NOT_FOUND)) => { + return match depth { + "0" => { + let path = location.map(|x| x.as_ref()).unwrap_or(""); + Err(crate::Error::NotFound { + path: path.to_string(), + source: Box::new(e), + }) + } + _ => { + // If prefix not found, return empty result set + Ok(Default::default()) + } + }; + } + Err(source) => return Err(Error::Request { source }.into()), + }; + + let status = quick_xml::de::from_reader(response.reader()) + .map_err(|source| Error::InvalidPropFind { source })?; + + Ok(status) + } + + pub(crate) async fn delete(&self, path: &Path) -> Result<()> { + let url = self.path_url(path); + self.client + .delete(url) + .send_retry(&self.retry_config) + .await + .map_err(|source| match source.status() { + Some(StatusCode::NOT_FOUND) => crate::Error::NotFound { + source: Box::new(source), + path: path.to_string(), + }, + _ => Error::Request { source }.into(), + })?; + Ok(()) + } + + pub(crate) async fn copy(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { + let mut retry = false; + loop { + let method = Method::from_bytes(b"COPY").unwrap(); + + let mut builder = self + .client + .request(method, self.path_url(from)) + .header("Destination", self.path_url(to).as_str()); + + if !overwrite { + // While the Overwrite header appears to duplicate + // the functionality of the If-Match: * header of HTTP/1.1, If-Match + // applies only to the Request-URI, and not to the Destination of a COPY + // or MOVE. 
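+                // With `Overwrite: F`, an existing destination makes the server reply
+                // 412 Precondition Failed, mapped to `AlreadyExists` below.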
+ builder = builder.header("Overwrite", "F"); + } + + return match builder.send_retry(&self.retry_config).await { + Ok(_) => Ok(()), + Err(source) => Err(match source.status() { + Some(StatusCode::PRECONDITION_FAILED) if !overwrite => { + crate::Error::AlreadyExists { + path: to.to_string(), + source: Box::new(source), + } + } + // Some implementations return 404 instead of 409 + Some(StatusCode::CONFLICT | StatusCode::NOT_FOUND) if !retry => { + retry = true; + self.create_parent_directories(to).await?; + continue; + } + _ => Error::Request { source }.into(), + }), + }; + } + } +} + +#[async_trait] +impl GetClient for Client { + const STORE: &'static str = "HTTP"; + + /// Override the [`HeaderConfig`] to be less strict to support a + /// broader range of HTTP servers (#4831) + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: false, + last_modified_required: false, + version_header: None, + user_defined_metadata_prefix: None, + }; + + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let url = self.path_url(path); + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + let has_range = options.range.is_some(); + let builder = self.client.request(method, url); + + let res = builder + .with_get_options(options) + .send_retry(&self.retry_config) + .await + .map_err(|source| match source.status() { + // Some stores return METHOD_NOT_ALLOWED for get on directories + Some(StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED) => { + crate::Error::NotFound { + source: Box::new(source), + path: path.to_string(), + } + } + _ => Error::Request { source }.into(), + })?; + + // We expect a 206 Partial Content response if a range was requested + // a 200 OK response would indicate the server did not fulfill the request + if has_range && res.status() != StatusCode::PARTIAL_CONTENT { + return Err(crate::Error::NotSupported { + source: Box::new(Error::RangeNotSupported { + href: path.to_string(), + }), + }); + } + + Ok(res) + } +} + +/// The response returned by a PROPFIND request, i.e. list +#[derive(Deserialize, Default)] +pub(crate) struct MultiStatus { + pub response: Vec, +} + +#[derive(Deserialize)] +pub(crate) struct MultiStatusResponse { + href: String, + #[serde(rename = "propstat")] + prop_stat: PropStat, +} + +impl MultiStatusResponse { + /// Returns an error if this response is not OK + pub(crate) fn check_ok(&self) -> Result<()> { + match self.prop_stat.status.contains("200 OK") { + true => Ok(()), + false => Err(Error::PropStatus { + href: self.href.clone(), + status: self.prop_stat.status.clone(), + } + .into()), + } + } + + /// Returns the resolved path of this element relative to `base_url` + pub(crate) fn path(&self, base_url: &Url) -> Result { + let url = Url::options() + .base_url(Some(base_url)) + .parse(&self.href) + .map_err(|source| Error::InvalidHref { + href: self.href.clone(), + source, + })?; + + // Reverse any percent encoding + let path = percent_decode_str(url.path()) + .decode_utf8() + .map_err(|source| Error::NonUnicode { + path: url.path().into(), + source, + })?; + + Ok(Path::parse(path.as_ref()).map_err(|source| { + let path = path.into(); + Error::InvalidPath { path, source } + })?) 
+ } + + fn size(&self) -> Result { + let size = self + .prop_stat + .prop + .content_length + .ok_or_else(|| Error::MissingSize { + href: self.href.clone(), + })?; + + Ok(size) + } + + /// Returns this objects metadata as [`ObjectMeta`] + pub(crate) fn object_meta(&self, base_url: &Url) -> Result { + let last_modified = self.prop_stat.prop.last_modified; + Ok(ObjectMeta { + location: self.path(base_url)?, + last_modified, + size: self.size()?, + e_tag: self.prop_stat.prop.e_tag.clone(), + version: None, + }) + } + + /// Returns true if this is a directory / collection + pub(crate) fn is_dir(&self) -> bool { + self.prop_stat.prop.resource_type.collection.is_some() + } +} + +#[derive(Deserialize)] +pub(crate) struct PropStat { + prop: Prop, + status: String, +} + +#[derive(Deserialize)] +pub(crate) struct Prop { + #[serde(deserialize_with = "deserialize_rfc1123", rename = "getlastmodified")] + last_modified: DateTime, + + #[serde(rename = "getcontentlength")] + content_length: Option, + + #[serde(rename = "resourcetype")] + resource_type: ResourceType, + + #[serde(rename = "getetag")] + e_tag: Option, +} + +#[derive(Deserialize)] +pub(crate) struct ResourceType { + collection: Option<()>, +} diff --git a/src/http/mod.rs b/src/http/mod.rs new file mode 100644 index 0000000..9786d83 --- /dev/null +++ b/src/http/mod.rs @@ -0,0 +1,290 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for generic HTTP servers +//! +//! This follows [rfc2518] commonly known as [WebDAV] +//! +//! Basic get support will work out of the box with most HTTP servers, +//! even those that don't explicitly support [rfc2518] +//! +//! Other operations such as list, delete, copy, etc... will likely +//! require server-side configuration. A list of HTTP servers with support +//! can be found [here](https://wiki.archlinux.org/title/WebDAV#Server) +//! +//! Multipart uploads are not currently supported +//! +//! [rfc2518]: https://datatracker.ietf.org/doc/html/rfc2518 +//! [WebDAV]: https://en.wikipedia.org/wiki/WebDAV + +use std::sync::Arc; + +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::{StreamExt, TryStreamExt}; +use itertools::Itertools; +use url::Url; + +use crate::client::get::GetClientExt; +use crate::client::header::get_etag; +use crate::client::{http_connector, HttpConnector}; +use crate::http::client::Client; +use crate::path::Path; +use crate::{ + ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMode, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, RetryConfig, +}; + +mod client; + +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("Must specify a URL")] + MissingUrl, + + #[error("Unable parse source url. 
Url: {}, Error: {}", url, source)] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[error("Unable to extract metadata from headers: {}", source)] + Metadata { + source: crate::client::header::Error, + }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + Self::Generic { + store: "HTTP", + source: Box::new(err), + } + } +} + +/// An [`ObjectStore`] implementation for generic HTTP servers +/// +/// See [`crate::http`] for more information +#[derive(Debug)] +pub struct HttpStore { + client: Arc, +} + +impl std::fmt::Display for HttpStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "HttpStore") + } +} + +#[async_trait] +impl ObjectStore for HttpStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + if opts.mode != PutMode::Overwrite { + // TODO: Add support for If header - https://datatracker.ietf.org/doc/html/rfc2518#section-9.4 + return Err(crate::Error::NotImplemented); + } + + let response = self.client.put(location, payload, opts.attributes).await?; + let e_tag = match get_etag(response.headers()) { + Ok(e_tag) => Some(e_tag), + Err(crate::client::header::Error::MissingEtag) => None, + Err(source) => return Err(Error::Metadata { source }.into()), + }; + + Ok(PutResult { + e_tag, + version: None, + }) + } + + async fn put_multipart_opts( + &self, + _location: &Path, + _opts: PutMultipartOpts, + ) -> Result> { + Err(crate::Error::NotImplemented) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.client.get_opts(location, options).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete(location).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let prefix_len = prefix.map(|p| p.as_ref().len()).unwrap_or_default(); + let prefix = prefix.cloned(); + let client = Arc::clone(&self.client); + futures::stream::once(async move { + let status = client.list(prefix.as_ref(), "infinity").await?; + + let iter = status + .response + .into_iter() + .filter(|r| !r.is_dir()) + .map(move |response| { + response.check_ok()?; + response.object_meta(client.base_url()) + }) + // Filter out exact prefix matches + .filter_ok(move |r| r.location.as_ref().len() > prefix_len); + + Ok::<_, crate::Error>(futures::stream::iter(iter)) + }) + .try_flatten() + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let status = self.client.list(prefix, "1").await?; + let prefix_len = prefix.map(|p| p.as_ref().len()).unwrap_or(0); + + let mut objects: Vec = Vec::with_capacity(status.response.len()); + let mut common_prefixes = Vec::with_capacity(status.response.len()); + for response in status.response { + response.check_ok()?; + match response.is_dir() { + false => { + let meta = response.object_meta(self.client.base_url())?; + // Filter out exact prefix matches + if meta.location.as_ref().len() > prefix_len { + objects.push(meta); + } + } + true => { + let path = response.path(self.client.base_url())?; + // Exclude the current object + if path.as_ref().len() > prefix_len { + common_prefixes.push(path); + } + } + } + } + + Ok(ListResult { + common_prefixes, + objects, + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy(from, to, true).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy(from, to, false).await + } +} + +/// Configure a 
connection to a generic HTTP server +#[derive(Debug, Default, Clone)] +pub struct HttpBuilder { + url: Option, + client_options: ClientOptions, + retry_config: RetryConfig, + http_connector: Option>, +} + +impl HttpBuilder { + /// Create a new [`HttpBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Set the URL + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Set individual client configuration without overriding the entire config + pub fn with_config(mut self, key: ClientConfigKey, value: impl Into) -> Self { + self.client_options = self.client_options.with_config(key, value); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// The [`HttpConnector`] to use + /// + /// On non-WASM32 platforms uses [`reqwest`] by default, on WASM32 platforms must be provided + pub fn with_http_connector(mut self, connector: C) -> Self { + self.http_connector = Some(Arc::new(connector)); + self + } + + /// Build an [`HttpStore`] with the configured options + pub fn build(self) -> Result { + let url = self.url.ok_or(Error::MissingUrl)?; + let parsed = Url::parse(&url).map_err(|source| Error::UnableToParseUrl { url, source })?; + + let client = http_connector(self.http_connector)?.connect(&self.client_options)?; + + Ok(HttpStore { + client: Arc::new(Client::new( + parsed, + client, + self.client_options, + self.retry_config, + )), + }) + } +} + +#[cfg(test)] +mod tests { + use crate::integration::*; + use crate::tests::*; + + use super::*; + + #[tokio::test] + async fn http_test() { + maybe_skip_integration!(); + let url = std::env::var("HTTP_URL").expect("HTTP_URL must be set"); + let options = ClientOptions::new().with_allow_http(true); + let integration = HttpBuilder::new() + .with_url(url) + .with_client_options(options) + .build() + .unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + } +} diff --git a/src/integration.rs b/src/integration.rs new file mode 100644 index 0000000..5a133f7 --- /dev/null +++ b/src/integration.rs @@ -0,0 +1,1229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for custom object store implementations +//! +//! NB: These tests will delete everything present in the provided [`DynObjectStore`]. +//! +//! 
These tests are not a stable part of the public API and breaking changes may be made +//! in patch releases. +//! +//! They are intended solely for testing purposes. + +use core::str; + +use crate::multipart::MultipartStore; +use crate::path::Path; +use crate::{ + Attribute, Attributes, DynObjectStore, Error, GetOptions, GetRange, MultipartUpload, + ObjectStore, PutMode, PutPayload, UpdateVersion, WriteMultipart, +}; +use bytes::Bytes; +use futures::stream::FuturesUnordered; +use futures::{StreamExt, TryStreamExt}; +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; + +pub(crate) async fn flatten_list_stream( + storage: &DynObjectStore, + prefix: Option<&Path>, +) -> crate::Result> { + storage + .list(prefix) + .map_ok(|meta| meta.location) + .try_collect::>() + .await +} + +/// Tests basic read/write and listing operations +pub async fn put_get_delete_list(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!( + content_list.is_empty(), + "Expected list to be empty; found: {content_list:?}" + ); + + let location = Path::from("test_dir/test_file.json"); + + let data = Bytes::from("arbitrary data"); + storage.put(&location, data.clone().into()).await.unwrap(); + + let root = Path::from("/"); + + // List everything + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // Should behave the same as no prefix + let content_list = flatten_list_stream(storage, Some(&root)).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // List with delimiter + let result = storage.list_with_delimiter(None).await.unwrap(); + assert_eq!(&result.objects, &[]); + assert_eq!(result.common_prefixes.len(), 1); + assert_eq!(result.common_prefixes[0], Path::from("test_dir")); + + // Should behave the same as no prefix + let result = storage.list_with_delimiter(Some(&root)).await.unwrap(); + assert!(result.objects.is_empty()); + assert_eq!(result.common_prefixes.len(), 1); + assert_eq!(result.common_prefixes[0], Path::from("test_dir")); + + // Should return not found + let err = storage.get(&Path::from("test_dir")).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + + // Should return not found + let err = storage.head(&Path::from("test_dir")).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + + // List everything starting with a prefix that should return results + let prefix = Path::from("test_dir"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // List everything starting with a prefix that shouldn't return results + let prefix = Path::from("something"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert!(content_list.is_empty()); + + let read_data = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(&*read_data, data); + + // Test range request + let range = 3..7; + let range_result = storage.get_range(&location, range.clone()).await; + + let bytes = range_result.unwrap(); + assert_eq!(bytes, data.slice(range.start as usize..range.end as usize)); + + let opts = GetOptions { + range: Some(GetRange::Bounded(2..5)), + ..Default::default() + }; + let result = storage.get_opts(&location, opts).await.unwrap(); + // Data is `"arbitrary data"`, length 14 bytes + assert_eq!(result.meta.size, 14); // Should return full object size (#5272) + assert_eq!(result.range, 2..5); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, b"bit".as_ref()); + + let out_of_range = 200..300; + let out_of_range_result = storage.get_range(&location, out_of_range).await; + + // Should be a non-fatal error + out_of_range_result.unwrap_err(); + + let opts = GetOptions { + range: Some(GetRange::Bounded(2..100)), + ..Default::default() + }; + let result = storage.get_opts(&location, opts).await.unwrap(); + assert_eq!(result.range, 2..14); + assert_eq!(result.meta.size, 14); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, b"bitrary data".as_ref()); + + let opts = GetOptions { + range: Some(GetRange::Suffix(2)), + ..Default::default() + }; + match storage.get_opts(&location, opts).await { + Ok(result) => { + assert_eq!(result.range, 12..14); + assert_eq!(result.meta.size, 14); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, b"ta".as_ref()); + } + Err(Error::NotSupported { .. }) => {} + Err(e) => panic!("{e}"), + } + + let opts = GetOptions { + range: Some(GetRange::Suffix(100)), + ..Default::default() + }; + match storage.get_opts(&location, opts).await { + Ok(result) => { + assert_eq!(result.range, 0..14); + assert_eq!(result.meta.size, 14); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, b"arbitrary data".as_ref()); + } + Err(Error::NotSupported { .. 
}) => {} + Err(e) => panic!("{e}"), + } + + let opts = GetOptions { + range: Some(GetRange::Offset(3)), + ..Default::default() + }; + let result = storage.get_opts(&location, opts).await.unwrap(); + assert_eq!(result.range, 3..14); + assert_eq!(result.meta.size, 14); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, b"itrary data".as_ref()); + + let opts = GetOptions { + range: Some(GetRange::Offset(100)), + ..Default::default() + }; + storage.get_opts(&location, opts).await.unwrap_err(); + + let ranges = vec![0..1, 2..3, 0..5]; + let bytes = storage.get_ranges(&location, &ranges).await.unwrap(); + for (range, bytes) in ranges.iter().zip(bytes) { + assert_eq!(bytes, data.slice(range.start as usize..range.end as usize)); + } + + let head = storage.head(&location).await.unwrap(); + assert_eq!(head.size, data.len() as u64); + + storage.delete(&location).await.unwrap(); + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); + + let err = storage.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + + let err = storage.head(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + + // Test handling of paths containing an encoded delimiter + + let file_with_delimiter = Path::from_iter(["a", "b/c", "foo.file"]); + storage + .put(&file_with_delimiter, "arbitrary".into()) + .await + .unwrap(); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(files, vec![file_with_delimiter.clone()]); + + let files = flatten_list_stream(storage, Some(&Path::from("a/b"))) + .await + .unwrap(); + assert!(files.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from("a/b"))) + .await + .unwrap(); + assert!(files.common_prefixes.is_empty()); + assert!(files.objects.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from("a"))) + .await + .unwrap(); + assert_eq!(files.common_prefixes, vec![Path::from_iter(["a", "b/c"])]); + assert!(files.objects.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from_iter(["a", "b/c"]))) + .await + .unwrap(); + assert!(files.common_prefixes.is_empty()); + assert_eq!(files.objects.len(), 1); + assert_eq!(files.objects[0].location, file_with_delimiter); + + storage.delete(&file_with_delimiter).await.unwrap(); + + // Test handling of paths containing non-ASCII characters, e.g. 
emoji + + let emoji_prefix = Path::from("🙀"); + let emoji_file = Path::from("🙀/😀.parquet"); + storage.put(&emoji_file, "arbitrary".into()).await.unwrap(); + + storage.head(&emoji_file).await.unwrap(); + storage + .get(&emoji_file) + .await + .unwrap() + .bytes() + .await + .unwrap(); + + let files = flatten_list_stream(storage, Some(&emoji_prefix)) + .await + .unwrap(); + + assert_eq!(files, vec![emoji_file.clone()]); + + let dst = Path::from("foo.parquet"); + storage.copy(&emoji_file, &dst).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst.clone()]); + + let dst2 = Path::from("new/nested/foo.parquet"); + storage.copy(&emoji_file, &dst2).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst.clone(), dst2.clone()]); + + let dst3 = Path::from("new/nested2/bar.parquet"); + storage.rename(&dst, &dst3).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst2.clone(), dst3.clone()]); + + let err = storage.head(&dst).await.unwrap_err(); + assert!(matches!(err, Error::NotFound { .. })); + + storage.delete(&emoji_file).await.unwrap(); + storage.delete(&dst3).await.unwrap(); + storage.delete(&dst2).await.unwrap(); + let files = flatten_list_stream(storage, Some(&emoji_prefix)) + .await + .unwrap(); + assert!(files.is_empty()); + + // Test handling of paths containing percent-encoded sequences + + // "HELLO" percent encoded + let hello_prefix = Path::parse("%48%45%4C%4C%4F").unwrap(); + let path = hello_prefix.child("foo.parquet"); + + storage.put(&path, vec![0, 1].into()).await.unwrap(); + let files = flatten_list_stream(storage, Some(&hello_prefix)) + .await + .unwrap(); + assert_eq!(files, vec![path.clone()]); + + // Cannot list by decoded representation + let files = flatten_list_stream(storage, Some(&Path::from("HELLO"))) + .await + .unwrap(); + assert!(files.is_empty()); + + // Cannot access by decoded representation + let err = storage + .head(&Path::from("HELLO/foo.parquet")) + .await + .unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + + storage.delete(&path).await.unwrap(); + + // Test handling of unicode paths + let path = Path::parse("🇦🇺/$shenanigans@@~.txt").unwrap(); + storage.put(&path, "test".into()).await.unwrap(); + + let r = storage.get(&path).await.unwrap(); + assert_eq!(r.bytes().await.unwrap(), "test"); + + let dir = Path::parse("🇦🇺").unwrap(); + let r = storage.list_with_delimiter(None).await.unwrap(); + assert!(r.common_prefixes.contains(&dir)); + + let r = storage.list_with_delimiter(Some(&dir)).await.unwrap(); + assert_eq!(r.objects.len(), 1); + assert_eq!(r.objects[0].location, path); + + storage.delete(&path).await.unwrap(); + + // Can also write non-percent encoded sequences + let path = Path::parse("%Q.parquet").unwrap(); + storage.put(&path, vec![0, 1].into()).await.unwrap(); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(files, vec![path.clone()]); + + storage.delete(&path).await.unwrap(); + + let path = Path::parse("foo bar/I contain spaces.parquet").unwrap(); + storage.put(&path, vec![0, 1].into()).await.unwrap(); + storage.head(&path).await.unwrap(); + + let files = flatten_list_stream(storage, Some(&Path::from("foo bar"))) + .await + .unwrap(); + assert_eq!(files, vec![path.clone()]); + + storage.delete(&path).await.unwrap(); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert!(files.is_empty(), "{files:?}"); + + // Test list order + let files = vec![ + Path::from("a a/b.file"), + Path::parse("a%2Fa.file").unwrap(), + Path::from("a/😀.file"), + Path::from("a/a file"), + Path::parse("a/a%2F.file").unwrap(), + Path::from("a/a.file"), + Path::from("a/a/b.file"), + Path::from("a/b.file"), + Path::from("aa/a.file"), + Path::from("ab/a.file"), + ]; + + for file in &files { + storage.put(file, "foo".into()).await.unwrap(); + } + + let cases = [ + (None, Path::from("a")), + (None, Path::from("a/a file")), + (None, Path::from("a/a/b.file")), + (None, Path::from("ab/a.file")), + (None, Path::from("a%2Fa.file")), + (None, Path::from("a/😀.file")), + (Some(Path::from("a")), Path::from("")), + (Some(Path::from("a")), Path::from("a")), + (Some(Path::from("a")), Path::from("a/😀")), + (Some(Path::from("a")), Path::from("a/😀.file")), + (Some(Path::from("a")), Path::from("a/b")), + (Some(Path::from("a")), Path::from("a/a/b.file")), + ]; + + for (prefix, offset) in cases { + let s = storage.list_with_offset(prefix.as_ref(), &offset); + let mut actual: Vec<_> = s.map_ok(|x| x.location).try_collect().await.unwrap(); + + actual.sort_unstable(); + + let expected: Vec<_> = files + .iter() + .filter(|x| { + let prefix_match = prefix.as_ref().map(|p| x.prefix_matches(p)).unwrap_or(true); + prefix_match && *x > &offset + }) + .cloned() + .collect(); + + assert_eq!(actual, expected, "{prefix:?} - {offset:?}"); + } + + // Test bulk delete + let paths = vec![ + Path::from("a/a.file"), + Path::from("a/a/b.file"), + Path::from("aa/a.file"), + Path::from("does_not_exist"), + Path::from("I'm a < & weird path"), + Path::from("ab/a.file"), + Path::from("a/😀.file"), + ]; + + storage.put(&paths[4], "foo".into()).await.unwrap(); + + let out_paths = storage + .delete_stream(futures::stream::iter(paths.clone()).map(Ok).boxed()) + .collect::>() + .await; + + assert_eq!(out_paths.len(), paths.len()); + + let expect_errors = [3]; + + for (i, input_path) in paths.iter().enumerate() { + let err = storage.head(input_path).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + + if expect_errors.contains(&i) { + // Some object stores will report NotFound, but others (such as S3) will + // report success regardless. + match &out_paths[i] { + Err(Error::NotFound { path: out_path, .. }) => { + assert!(out_path.ends_with(&input_path.to_string())); + } + Ok(out_path) => { + assert_eq!(out_path, input_path); + } + _ => panic!("unexpected error"), + } + } else { + assert_eq!(out_paths[i].as_ref().unwrap(), input_path); + } + } + + delete_fixtures(storage).await; + + let path = Path::from("empty"); + storage.put(&path, PutPayload::default()).await.unwrap(); + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.size, 0); + let data = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(data.len(), 0); + + storage.delete(&path).await.unwrap(); +} + +/// Tests the ability to read and write [`Attributes`] +pub async fn put_get_attributes(integration: &dyn ObjectStore) { + // Test handling of attributes + let attributes = Attributes::from_iter([ + (Attribute::CacheControl, "max-age=604800"), + ( + Attribute::ContentDisposition, + r#"attachment; filename="test.html""#, + ), + (Attribute::ContentEncoding, "gzip"), + (Attribute::ContentLanguage, "en-US"), + (Attribute::ContentType, "text/html; charset=utf-8"), + (Attribute::Metadata("test_key".into()), "test_value"), + ]); + + let path = Path::from("attributes"); + let opts = attributes.clone().into(); + match integration.put_opts(&path, "foo".into(), opts).await { + Ok(_) => { + let r = integration.get(&path).await.unwrap(); + assert_eq!(r.attributes, attributes); + } + Err(Error::NotImplemented) => {} + Err(e) => panic!("{e}"), + } + + let opts = attributes.clone().into(); + match integration.put_multipart_opts(&path, opts).await { + Ok(mut w) => { + w.put_part("foo".into()).await.unwrap(); + w.complete().await.unwrap(); + + let r = integration.get(&path).await.unwrap(); + assert_eq!(r.attributes, attributes); + } + Err(Error::NotImplemented) => {} + Err(e) => panic!("{e}"), + } +} + +/// Tests conditional read requests +pub async fn get_opts(storage: &dyn ObjectStore) { + let path = Path::from("test"); + storage.put(&path, "foo".into()).await.unwrap(); + let meta = storage.head(&path).await.unwrap(); + + let options = GetOptions { + if_unmodified_since: Some(meta.last_modified), + ..GetOptions::default() + }; + match storage.get_opts(&path, options).await { + Ok(_) | Err(Error::NotSupported { .. }) => {} + Err(e) => panic!("{e}"), + } + + let options = GetOptions { + if_unmodified_since: Some(meta.last_modified + chrono::Duration::try_hours(10).unwrap()), + ..GetOptions::default() + }; + match storage.get_opts(&path, options).await { + Ok(_) | Err(Error::NotSupported { .. }) => {} + Err(e) => panic!("{e}"), + } + + let options = GetOptions { + if_unmodified_since: Some(meta.last_modified - chrono::Duration::try_hours(10).unwrap()), + ..GetOptions::default() + }; + match storage.get_opts(&path, options).await { + Err(Error::Precondition { .. } | Error::NotSupported { .. }) => {} + d => panic!("{d:?}"), + } + + let options = GetOptions { + if_modified_since: Some(meta.last_modified), + ..GetOptions::default() + }; + match storage.get_opts(&path, options).await { + Err(Error::NotModified { .. } | Error::NotSupported { .. 
}) => {} + d => panic!("{d:?}"), + } + + let options = GetOptions { + if_modified_since: Some(meta.last_modified - chrono::Duration::try_hours(10).unwrap()), + ..GetOptions::default() + }; + match storage.get_opts(&path, options).await { + Ok(_) | Err(Error::NotSupported { .. }) => {} + Err(e) => panic!("{e}"), + } + + let tag = meta.e_tag.unwrap(); + let options = GetOptions { + if_match: Some(tag.clone()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let options = GetOptions { + if_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + let options = GetOptions { + if_none_match: Some(tag.clone()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::NotModified { .. }), "{err}"); + + let options = GetOptions { + if_none_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let result = storage.put(&path, "test".into()).await.unwrap(); + let new_tag = result.e_tag.unwrap(); + assert_ne!(tag, new_tag); + + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.e_tag.unwrap(), new_tag); + + let options = GetOptions { + if_match: Some(new_tag), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let options = GetOptions { + if_match: Some(tag), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + if let Some(version) = meta.version { + storage.put(&path, "bar".into()).await.unwrap(); + + let options = GetOptions { + version: Some(version), + ..GetOptions::default() + }; + + // Can retrieve previous version + let get_opts = storage.get_opts(&path, options).await.unwrap(); + let old = get_opts.bytes().await.unwrap(); + assert_eq!(old, b"test".as_slice()); + + // Current version contains the updated data + let current = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(¤t, b"bar".as_slice()); + } +} + +/// Tests conditional writes +pub async fn put_opts(storage: &dyn ObjectStore, supports_update: bool) { + // When using DynamoCommit repeated runs of this test will produce the same sequence of records in DynamoDB + // As a result each conditional operation will need to wait for the lease to timeout before proceeding + // One solution would be to clear DynamoDB before each test, but this would require non-trivial additional code + // so we instead just generate a random suffix for the filenames + let rng = thread_rng(); + let suffix = String::from_utf8(rng.sample_iter(Alphanumeric).take(32).collect()).unwrap(); + + delete_fixtures(storage).await; + let path = Path::from(format!("put_opts_{suffix}")); + let v1 = storage + .put_opts(&path, "a".into(), PutMode::Create.into()) + .await + .unwrap(); + + let err = storage + .put_opts(&path, "b".into(), PutMode::Create.into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::AlreadyExists { .. }), "{err}"); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"a"); + + if !supports_update { + let err = storage + .put_opts(&path, "c".into(), PutMode::Update(v1.clone().into()).into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::NotImplemented { .. 
}), "{err}"); + + return; + } + + let v2 = storage + .put_opts(&path, "c".into(), PutMode::Update(v1.clone().into()).into()) + .await + .unwrap(); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"c"); + + let err = storage + .put_opts(&path, "d".into(), PutMode::Update(v1.into()).into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + storage + .put_opts(&path, "e".into(), PutMode::Update(v2.clone().into()).into()) + .await + .unwrap(); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"e"); + + // Update not exists + let path = Path::from("I don't exist"); + let err = storage + .put_opts(&path, "e".into(), PutMode::Update(v2.into()).into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + const NUM_WORKERS: usize = 5; + const NUM_INCREMENTS: usize = 10; + + let path = Path::from(format!("RACE-{suffix}")); + let mut futures: FuturesUnordered<_> = (0..NUM_WORKERS) + .map(|_| async { + for _ in 0..NUM_INCREMENTS { + loop { + match storage.get(&path).await { + Ok(r) => { + let mode = PutMode::Update(UpdateVersion { + e_tag: r.meta.e_tag.clone(), + version: r.meta.version.clone(), + }); + + let b = r.bytes().await.unwrap(); + let v: usize = std::str::from_utf8(&b).unwrap().parse().unwrap(); + let new = (v + 1).to_string(); + + match storage.put_opts(&path, new.into(), mode.into()).await { + Ok(_) => break, + Err(Error::Precondition { .. }) => continue, + Err(e) => return Err(e), + } + } + Err(Error::NotFound { .. }) => { + let mode = PutMode::Create; + match storage.put_opts(&path, "1".into(), mode.into()).await { + Ok(_) => break, + Err(Error::AlreadyExists { .. }) => continue, + Err(e) => return Err(e), + } + } + Err(e) => return Err(e), + } + } + } + Ok(()) + }) + .collect(); + + while futures.next().await.transpose().unwrap().is_some() {} + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + let v = std::str::from_utf8(&b).unwrap().parse::().unwrap(); + assert_eq!(v, NUM_WORKERS * NUM_INCREMENTS); +} + +/// Returns a chunk of length `chunk_length` +fn get_chunk(chunk_length: usize) -> Bytes { + let mut data = vec![0_u8; chunk_length]; + let mut rng = thread_rng(); + // Set a random selection of bytes + for _ in 0..1000 { + data[rng.gen_range(0..chunk_length)] = rng.gen(); + } + data.into() +} + +/// Returns `num_chunks` of length `chunks` +fn get_chunks(chunk_length: usize, num_chunks: usize) -> Vec { + (0..num_chunks).map(|_| get_chunk(chunk_length)).collect() +} + +/// Tests the ability to perform multipart writes +pub async fn stream_get(storage: &DynObjectStore) { + let location = Path::from("test_dir/test_upload_file.txt"); + + // Can write to storage + let data = get_chunks(5 * 1024 * 1024, 3); + let bytes_expected = data.concat(); + let mut upload = storage.put_multipart(&location).await.unwrap(); + let uploads = data.into_iter().map(|x| upload.put_part(x.into())); + futures::future::try_join_all(uploads).await.unwrap(); + + // Object should not yet exist in store + let meta_res = storage.head(&location).await; + assert!(meta_res.is_err()); + assert!(matches!( + meta_res.unwrap_err(), + crate::Error::NotFound { .. 
} + )); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(&files, &[]); + + let result = storage.list_with_delimiter(None).await.unwrap(); + assert_eq!(&result.objects, &[]); + + upload.complete().await.unwrap(); + + let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(bytes_expected, bytes_written); + + // Can overwrite some storage + // Sizes chosen to ensure we write three parts + let data = get_chunks(3_200_000, 7); + let bytes_expected = data.concat(); + let upload = storage.put_multipart(&location).await.unwrap(); + let mut writer = WriteMultipart::new(upload); + for chunk in &data { + writer.write(chunk) + } + writer.finish().await.unwrap(); + let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(bytes_expected, bytes_written); + + let location = Path::from("test_dir/test_put_part.txt"); + let upload = storage.put_multipart(&location).await.unwrap(); + let mut write = WriteMultipart::new(upload); + write.put(vec![0; 2].into()); + write.put(vec![3; 4].into()); + write.finish().await.unwrap(); + + let meta = storage.head(&location).await.unwrap(); + assert_eq!(meta.size, 6); + + let location = Path::from("test_dir/test_put_part_mixed.txt"); + let upload = storage.put_multipart(&location).await.unwrap(); + let mut write = WriteMultipart::new(upload); + write.put(vec![0; 2].into()); + write.write(&[1, 2, 3]); + write.put(vec![4, 5, 6, 7].into()); + write.finish().await.unwrap(); + + let r = storage.get(&location).await.unwrap(); + let r = r.bytes().await.unwrap(); + assert_eq!(r.as_ref(), &[0, 0, 1, 2, 3, 4, 5, 6, 7]); + + // We can abort an empty write + let location = Path::from("test_dir/test_abort_upload.txt"); + let mut upload = storage.put_multipart(&location).await.unwrap(); + upload.abort().await.unwrap(); + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!( + get_res.unwrap_err(), + crate::Error::NotFound { .. } + )); + + // We can abort an in-progress write + let mut upload = storage.put_multipart(&location).await.unwrap(); + upload + .put_part(data.first().unwrap().clone().into()) + .await + .unwrap(); + + upload.abort().await.unwrap(); + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!(get_res.unwrap_err(), Error::NotFound { .. 
})); +} + +/// Tests that directories are transparent +pub async fn list_uses_directories_correctly(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!( + content_list.is_empty(), + "Expected list to be empty; found: {content_list:?}" + ); + + let location1 = Path::from("foo/x.json"); + let location2 = Path::from("foo.bar/y.json"); + + let data = PutPayload::from("arbitrary data"); + storage.put(&location1, data.clone()).await.unwrap(); + storage.put(&location2, data).await.unwrap(); + + let prefix = Path::from("foo"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[location1.clone()]); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert_eq!(result.objects.len(), 1); + assert_eq!(result.objects[0].location, location1); + assert_eq!(result.common_prefixes, &[]); + + // Listing an existing path (file) should return an empty list: + // https://github.com/apache/arrow-rs/issues/3712 + let content_list = flatten_list_stream(storage, Some(&location1)) + .await + .unwrap(); + assert_eq!(content_list, &[]); + + let list = storage.list_with_delimiter(Some(&location1)).await.unwrap(); + assert_eq!(list.objects, &[]); + assert_eq!(list.common_prefixes, &[]); + + let prefix = Path::from("foo/x"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[]); + + let list = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert_eq!(list.objects, &[]); + assert_eq!(list.common_prefixes, &[]); +} + +/// Tests listing with delimiter +pub async fn list_with_delimiter(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + // ==================== check: store is empty ==================== + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); + + // ==================== do: create files ==================== + let data = Bytes::from("arbitrary data"); + + let files: Vec<_> = [ + "test_file", + "mydb/wb/000/000/000.segment", + "mydb/wb/000/000/001.segment", + "mydb/wb/000/000/002.segment", + "mydb/wb/001/001/000.segment", + "mydb/wb/foo.json", + "mydb/wbwbwb/111/222/333.segment", + "mydb/data/whatevs", + ] + .iter() + .map(|&s| Path::from(s)) + .collect(); + + for f in &files { + storage.put(f, data.clone().into()).await.unwrap(); + } + + // ==================== check: prefix-list `mydb/wb` (directory) ==================== + let prefix = Path::from("mydb/wb"); + + let expected_000 = Path::from("mydb/wb/000"); + let expected_001 = Path::from("mydb/wb/001"); + let expected_location = Path::from("mydb/wb/foo.json"); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + + assert_eq!(result.common_prefixes, vec![expected_000, expected_001]); + assert_eq!(result.objects.len(), 1); + + let object = &result.objects[0]; + + assert_eq!(object.location, expected_location); + assert_eq!(object.size, data.len() as u64); + + // ==================== check: prefix-list `mydb/wb/000/000/001` (partial filename doesn't match) ==================== + let prefix = Path::from("mydb/wb/000/000/001"); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert!(result.common_prefixes.is_empty()); + assert_eq!(result.objects.len(), 0); + + // ==================== check: prefix-list `not_there` (non-existing prefix) ==================== + let prefix = Path::from("not_there"); + 
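+ // Listing a prefix with no matching objects should succeed with an empty
+ // result set rather than returning a NotFound error.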
+ let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert!(result.common_prefixes.is_empty()); + assert!(result.objects.is_empty()); + + // ==================== do: remove all files ==================== + for f in &files { + storage.delete(f).await.unwrap(); + } + + // ==================== check: store is empty ==================== + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); +} + +/// Tests fetching a non-existent object returns a not found error +pub async fn get_nonexistent_object( + storage: &DynObjectStore, + location: Option, +) -> crate::Result { + let location = location.unwrap_or_else(|| Path::from("this_file_should_not_exist")); + + let err = storage.head(&location).await.unwrap_err(); + assert!(matches!(err, Error::NotFound { .. })); + + storage.get(&location).await?.bytes().await +} + +/// Tests copying +pub async fn rename_and_copy(storage: &DynObjectStore) { + // Create two objects + let path1 = Path::from("test1"); + let path2 = Path::from("test2"); + let contents1 = Bytes::from("cats"); + let contents2 = Bytes::from("dogs"); + + // copy() make both objects identical + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); + storage.copy(&path1, &path2).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + + // rename() copies contents and deletes original + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); + storage.rename(&path1, &path2).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + let result = storage.get(&path1).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Error::NotFound { .. })); + + // Clean up + storage.delete(&path2).await.unwrap(); +} + +/// Tests copy if not exists +pub async fn copy_if_not_exists(storage: &DynObjectStore) { + // Create two objects + let path1 = Path::from("test1"); + let path2 = Path::from("not_exists_nested/test2"); + let contents1 = Bytes::from("cats"); + let contents2 = Bytes::from("dogs"); + + // copy_if_not_exists() errors if destination already exists + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); + let result = storage.copy_if_not_exists(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::Error::AlreadyExists { .. } + )); + + // copy_if_not_exists() copies contents and allows deleting original + storage.delete(&path2).await.unwrap(); + storage.copy_if_not_exists(&path1, &path2).await.unwrap(); + storage.delete(&path1).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + let result = storage.get(&path1).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. 
})); + + // Clean up + storage.delete(&path2).await.unwrap(); +} + +/// Tests copy and renaming behaviour of non-existent objects +pub async fn copy_rename_nonexistent_object(storage: &DynObjectStore) { + // Create empty source object + let path1 = Path::from("test1"); + + // Create destination object + let path2 = Path::from("test2"); + storage.put(&path2, "hello".into()).await.unwrap(); + + // copy() errors if source does not exist + let result = storage.copy(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // rename() errors if source does not exist + let result = storage.rename(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // copy_if_not_exists() errors if source does not exist + let result = storage.copy_if_not_exists(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // Clean up + storage.delete(&path2).await.unwrap(); +} + +/// Tests [`MultipartStore`] +pub async fn multipart(storage: &dyn ObjectStore, multipart: &dyn MultipartStore) { + let path = Path::from("test_multipart"); + let chunk_size = 5 * 1024 * 1024; + + let chunks = get_chunks(chunk_size, 2); + + let id = multipart.create_multipart(&path).await.unwrap(); + + let parts: Vec<_> = futures::stream::iter(chunks) + .enumerate() + .map(|(idx, b)| multipart.put_part(&path, &id, idx, b.into())) + .buffered(2) + .try_collect() + .await + .unwrap(); + + multipart + .complete_multipart(&path, &id, parts) + .await + .unwrap(); + + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.size, chunk_size as u64 * 2); + + // Empty case + let path = Path::from("test_empty_multipart"); + + let id = multipart.create_multipart(&path).await.unwrap(); + + let parts = vec![]; + + multipart + .complete_multipart(&path, &id, parts) + .await + .unwrap(); + + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.size, 0); +} + +async fn delete_fixtures(storage: &DynObjectStore) { + let paths = storage.list(None).map_ok(|meta| meta.location).boxed(); + storage + .delete_stream(paths) + .try_collect::>() + .await + .unwrap(); +} + +/// Tests a race condition where 2 threads are performing multipart writes to the same path +pub async fn multipart_race_condition(storage: &dyn ObjectStore, last_writer_wins: bool) { + let path = Path::from("test_multipart_race_condition"); + + let mut multipart_upload_1 = storage.put_multipart(&path).await.unwrap(); + let mut multipart_upload_2 = storage.put_multipart(&path).await.unwrap(); + + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 0)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 0)).into()) + .await + .unwrap(); + + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 1)).into()) + .await + .unwrap(); + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 1)).into()) + .await + .unwrap(); + + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 2)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 2)).into()) + .await + .unwrap(); + + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 3)).into()) + .await + .unwrap(); + multipart_upload_1 + .put_part(Bytes::from(format!("1:{:05300000},", 3)).into()) + .await + .unwrap(); + + multipart_upload_1 + 
.put_part(Bytes::from(format!("1:{:05300000},", 4)).into()) + .await + .unwrap(); + multipart_upload_2 + .put_part(Bytes::from(format!("2:{:05300000},", 4)).into()) + .await + .unwrap(); + + multipart_upload_1.complete().await.unwrap(); + + if last_writer_wins { + multipart_upload_2.complete().await.unwrap(); + } else { + let err = multipart_upload_2.complete().await.unwrap_err(); + + assert!(matches!(err, crate::Error::Generic { .. }), "{err}"); + } + + let get_result = storage.get(&path).await.unwrap(); + let bytes = get_result.bytes().await.unwrap(); + let string_contents = str::from_utf8(&bytes).unwrap(); + + if last_writer_wins { + assert!(string_contents.starts_with( + format!( + "2:{:05300000},2:{:05300000},2:{:05300000},2:{:05300000},2:{:05300000},", + 0, 1, 2, 3, 4 + ) + .as_str() + )); + } else { + assert!(string_contents.starts_with( + format!( + "1:{:05300000},1:{:05300000},1:{:05300000},1:{:05300000},1:{:05300000},", + 0, 1, 2, 3, 4 + ) + .as_str() + )); + } +} + +/// Tests performing out of order multipart uploads +pub async fn multipart_out_of_order(storage: &dyn ObjectStore) { + let path = Path::from("test_multipart_out_of_order"); + let mut multipart_upload = storage.put_multipart(&path).await.unwrap(); + + let part1 = std::iter::repeat(b'1') + .take(5 * 1024 * 1024) + .collect::(); + let part2 = std::iter::repeat(b'2') + .take(5 * 1024 * 1024) + .collect::(); + let part3 = std::iter::repeat(b'3') + .take(5 * 1024 * 1024) + .collect::(); + let full = [part1.as_ref(), part2.as_ref(), part3.as_ref()].concat(); + + let fut1 = multipart_upload.put_part(part1.into()); + let fut2 = multipart_upload.put_part(part2.into()); + let fut3 = multipart_upload.put_part(part3.into()); + // note order is 2,3,1 , different than the parts were created in + fut2.await.unwrap(); + fut3.await.unwrap(); + fut1.await.unwrap(); + + multipart_upload.complete().await.unwrap(); + + let result = storage.get(&path).await.unwrap(); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes, full); +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ec660df --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1629 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr, + unreachable_pub +)] + +//! # object_store +//! +//! This crate provides a uniform API for interacting with object +//! storage services and local files via the [`ObjectStore`] +//! trait. +//! +//! 
Using this crate, the same binary and code can run in multiple +//! clouds and local test environments, via a simple runtime +//! configuration change. +//! +//! # Highlights +//! +//! 1. A high-performance async API focused on providing a consistent interface +//! mirroring that of object stores such as [S3] +//! +//! 2. Production quality, leading this crate to be used in large +//! scale production systems, such as [crates.io] and [InfluxDB IOx] +//! +//! 3. Support for advanced functionality, including atomic, conditional reads +//! and writes, vectored IO, bulk deletion, and more... +//! +//! 4. Stable and predictable governance via the [Apache Arrow] project +//! +//! 5. Small dependency footprint, depending on only a small number of common crates +//! +//! Originally developed by [InfluxData] and subsequently donated +//! to [Apache Arrow]. +//! +//! [Apache Arrow]: https://arrow.apache.org/ +//! [InfluxData]: https://www.influxdata.com/ +//! [crates.io]: https://github.com/rust-lang/crates.io +//! [ACID]: https://en.wikipedia.org/wiki/ACID +//! [S3]: https://aws.amazon.com/s3/ +//! +//! # Available [`ObjectStore`] Implementations +//! +//! By default, this crate provides the following implementations: +//! +//! * Memory: [`InMemory`](memory::InMemory) +//! +//! Feature flags are used to enable support for other implementations: +//! +#![cfg_attr( + feature = "fs", + doc = "* Local filesystem: [`LocalFileSystem`](local::LocalFileSystem)" +)] +#![cfg_attr( + feature = "gcp", + doc = "* [`gcp`]: [Google Cloud Storage](https://cloud.google.com/storage/) support. See [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" +)] +#![cfg_attr( + feature = "aws", + doc = "* [`aws`]: [Amazon S3](https://aws.amazon.com/s3/). See [`AmazonS3Builder`](aws::AmazonS3Builder)" +)] +#![cfg_attr( + feature = "azure", + doc = "* [`azure`]: [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/). See [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder)" +)] +#![cfg_attr( + feature = "http", + doc = "* [`http`]: [HTTP/WebDAV Storage](https://datatracker.ietf.org/doc/html/rfc2518). See [`HttpBuilder`](http::HttpBuilder)" +)] +//! +//! # Why not a Filesystem Interface? +//! +//! The [`ObjectStore`] interface is designed to mirror the APIs +//! of object stores and *not* filesystems, and thus has stateless APIs instead +//! of cursor based interfaces such as [`Read`] or [`Seek`] available in filesystems. +//! +//! This design provides the following advantages: +//! +//! * All operations are atomic, and readers cannot observe partial and/or failed writes +//! * Methods map directly to object store APIs, providing both efficiency and predictability +//! * Abstracts away filesystem and operating system specific quirks, ensuring portability +//! * Allows for functionality not native to filesystems, such as operation preconditions +//! and atomic multipart uploads +//! +//! This crate does provide [`BufReader`] and [`BufWriter`] adapters +//! which provide a more filesystem-like API for working with the +//! [`ObjectStore`] trait, however, they should be used with care +//! +//! [`BufReader`]: buffered::BufReader +//! [`BufWriter`]: buffered::BufWriter +//! +//! # Adapters +//! +//! [`ObjectStore`] instances can be composed with various adapters +//! which add additional functionality: +//! +//! * Rate Throttling: [`ThrottleConfig`](throttle::ThrottleConfig) +//! * Concurrent Request Limit: [`LimitStore`](limit::LimitStore) +//! +//! # Configuration System +//! +//! 
This crate provides a configuration system inspired by the APIs exposed by [fsspec], +//! [PyArrow FileSystem], and [Hadoop FileSystem], allowing creating a [`DynObjectStore`] +//! from a URL and an optional list of key value pairs. This provides a flexible interface +//! to support a wide variety of user-defined store configurations, with minimal additional +//! application complexity. +//! +//! ```no_run +//! # #[cfg(feature = "aws")] { +//! # use url::Url; +//! # use object_store::{parse_url, parse_url_opts}; +//! # use object_store::aws::{AmazonS3, AmazonS3Builder}; +//! # +//! # +//! // Can manually create a specific store variant using the appropriate builder +//! let store: AmazonS3 = AmazonS3Builder::from_env() +//! .with_bucket_name("my-bucket").build().unwrap(); +//! +//! // Alternatively can create an ObjectStore from an S3 URL +//! let url = Url::parse("s3://bucket/path").unwrap(); +//! let (store, path) = parse_url(&url).unwrap(); +//! assert_eq!(path.as_ref(), "path"); +//! +//! // Potentially with additional options +//! let (store, path) = parse_url_opts(&url, vec![("aws_access_key_id", "...")]).unwrap(); +//! +//! // Or with URLs that encode the bucket name in the URL path +//! let url = Url::parse("https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket/path").unwrap(); +//! let (store, path) = parse_url(&url).unwrap(); +//! assert_eq!(path.as_ref(), "path"); +//! # } +//! ``` +//! +//! [PyArrow FileSystem]: https://arrow.apache.org/docs/python/generated/pyarrow.fs.FileSystem.html#pyarrow.fs.FileSystem.from_uri +//! [fsspec]: https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem +//! [Hadoop FileSystem]: https://hadoop.apache.org/docs/r3.0.0/api/org/apache/hadoop/fs/FileSystem.html#get-java.net.URI-org.apache.hadoop.conf.Configuration- +//! +//! # List objects +//! +//! Use the [`ObjectStore::list`] method to iterate over objects in +//! remote storage or files in the local filesystem: +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # use std::sync::Arc; +//! # use object_store::{path::Path, ObjectStore}; +//! # use futures::stream::StreamExt; +//! # // use LocalFileSystem for example +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # +//! # async fn example() { +//! # +//! // create an ObjectStore +//! let object_store: Arc = get_object_store(); +//! +//! // Recursively list all files below the 'data' path. +//! // 1. On AWS S3 this would be the 'data/' prefix +//! // 2. On a local filesystem, this would be the 'data' directory +//! let prefix = Path::from("data"); +//! +//! // Get an `async` stream of Metadata objects: +//! let mut list_stream = object_store.list(Some(&prefix)); +//! +//! // Print a line about each object +//! while let Some(meta) = list_stream.next().await.transpose().unwrap() { +//! println!("Name: {}, size: {}", meta.location, meta.size); +//! } +//! # } +//! ``` +//! +//! Which will print out something like the following: +//! +//! ```text +//! Name: data/file01.parquet, size: 112832 +//! Name: data/file02.parquet, size: 143119 +//! Name: data/child/file03.parquet, size: 100 +//! ... +//! ``` +//! +//! # Fetch objects +//! +//! Use the [`ObjectStore::get`] method to fetch the data bytes +//! from remote storage or files in the local filesystem as a stream. +//! +//! ``` +//! # use futures::TryStreamExt; +//! # use object_store::local::LocalFileSystem; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use object_store::{path::Path, ObjectStore, GetResult}; +//! 
# fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # +//! # async fn example() { +//! # +//! // Create an ObjectStore +//! let object_store: Arc = get_object_store(); +//! +//! // Retrieve a specific file +//! let path = Path::from("data/file01.parquet"); +//! +//! // Fetch just the file metadata +//! let meta = object_store.head(&path).await.unwrap(); +//! println!("{meta:?}"); +//! +//! // Fetch the object including metadata +//! let result: GetResult = object_store.get(&path).await.unwrap(); +//! assert_eq!(result.meta, meta); +//! +//! // Buffer the entire object in memory +//! let object: Bytes = result.bytes().await.unwrap(); +//! assert_eq!(object.len() as u64, meta.size); +//! +//! // Alternatively stream the bytes from object storage +//! let stream = object_store.get(&path).await.unwrap().into_stream(); +//! +//! // Count the '0's using `try_fold` from `TryStreamExt` trait +//! let num_zeros = stream +//! .try_fold(0, |acc, bytes| async move { +//! Ok(acc + bytes.iter().filter(|b| **b == 0).count()) +//! }).await.unwrap(); +//! +//! println!("Num zeros in {} is {}", path, num_zeros); +//! # } +//! ``` +//! +//! # Put Object +//! +//! Use the [`ObjectStore::put`] method to atomically write data. +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # use object_store::{ObjectStore, PutPayload}; +//! # use std::sync::Arc; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # async fn put() { +//! # +//! let object_store: Arc = get_object_store(); +//! let path = Path::from("data/file1"); +//! let payload = PutPayload::from_static(b"hello"); +//! object_store.put(&path, payload).await.unwrap(); +//! # } +//! ``` +//! +//! # Multipart Upload +//! +//! Use the [`ObjectStore::put_multipart`] method to atomically write a large amount of data +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # use object_store::{ObjectStore, WriteMultipart}; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # async fn multi_upload() { +//! # +//! let object_store: Arc = get_object_store(); +//! let path = Path::from("data/large_file"); +//! let upload = object_store.put_multipart(&path).await.unwrap(); +//! let mut write = WriteMultipart::new(upload); +//! write.write(b"hello"); +//! write.finish().await.unwrap(); +//! # } +//! ``` +//! +//! # Vectored Read +//! +//! A common pattern, especially when reading structured datasets, is to need to fetch +//! multiple, potentially non-contiguous, ranges of a particular object. +//! +//! [`ObjectStore::get_ranges`] provides an efficient way to perform such vectored IO, and will +//! automatically coalesce adjacent ranges into an appropriate number of parallel requests. +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # use object_store::ObjectStore; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # async fn multi_upload() { +//! # +//! let object_store: Arc = get_object_store(); +//! let path = Path::from("data/large_file"); +//! let ranges = object_store.get_ranges(&path, &[90..100, 400..600, 0..10]).await.unwrap(); +//! assert_eq!(ranges.len(), 3); +//! 
assert_eq!(ranges[0].len(), 10); +//! # } +//! ``` +//! +//! # Vectored Write +//! +//! When writing data it is often the case that the size of the output is not known ahead of time. +//! +//! A common approach to handling this is to bump-allocate a `Vec`, whereby the underlying +//! allocation is repeatedly reallocated, each time doubling the capacity. The performance of +//! this is suboptimal as reallocating memory will often involve copying it to a new location. +//! +//! Fortunately, as [`PutPayload`] does not require memory regions to be contiguous, it is +//! possible to instead allocate memory in chunks and avoid bump allocating. [`PutPayloadMut`] +//! encapsulates this approach +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # use object_store::{ObjectStore, PutPayloadMut}; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) +//! # } +//! # async fn multi_upload() { +//! # +//! let object_store: Arc = get_object_store(); +//! let path = Path::from("data/large_file"); +//! let mut buffer = PutPayloadMut::new().with_block_size(8192); +//! for _ in 0..22 { +//! buffer.extend_from_slice(&[0; 1024]); +//! } +//! let payload = buffer.freeze(); +//! +//! // Payload consists of 3 separate 8KB allocations +//! assert_eq!(payload.as_ref().len(), 3); +//! assert_eq!(payload.as_ref()[0].len(), 8192); +//! assert_eq!(payload.as_ref()[1].len(), 8192); +//! assert_eq!(payload.as_ref()[2].len(), 6144); +//! +//! object_store.put(&path, payload).await.unwrap(); +//! # } +//! ``` +//! +//! # Conditional Fetch +//! +//! More complex object retrieval can be supported by [`ObjectStore::get_opts`]. +//! +//! For example, efficiently refreshing a cache without re-fetching the entire object +//! data if the object hasn't been modified. +//! +//! ``` +//! # use std::collections::btree_map::Entry; +//! # use std::collections::HashMap; +//! # use object_store::{GetOptions, GetResult, ObjectStore, Result, Error}; +//! # use std::sync::Arc; +//! # use std::time::{Duration, Instant}; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! struct CacheEntry { +//! /// Data returned by last request +//! data: Bytes, +//! /// ETag identifying the object returned by the server +//! e_tag: String, +//! /// Instant of last refresh +//! refreshed_at: Instant, +//! } +//! +//! /// Example cache that checks entries after 10 seconds for a new version +//! struct Cache { +//! entries: HashMap, +//! store: Arc, +//! } +//! +//! impl Cache { +//! pub async fn get(&mut self, path: &Path) -> Result { +//! Ok(match self.entries.get_mut(path) { +//! Some(e) => match e.refreshed_at.elapsed() < Duration::from_secs(10) { +//! true => e.data.clone(), // Return cached data +//! false => { // Check if remote version has changed +//! let opts = GetOptions { +//! if_none_match: Some(e.e_tag.clone()), +//! ..GetOptions::default() +//! }; +//! match self.store.get_opts(&path, opts).await { +//! Ok(d) => e.data = d.bytes().await?, +//! Err(Error::NotModified { .. }) => {} // Data has not changed +//! Err(e) => return Err(e), +//! }; +//! e.refreshed_at = Instant::now(); +//! e.data.clone() +//! } +//! }, +//! None => { // Not cached, fetch data +//! let get = self.store.get(&path).await?; +//! let e_tag = get.meta.e_tag.clone(); +//! let data = get.bytes().await?; +//! if let Some(e_tag) = e_tag { +//! 
let entry = CacheEntry { +//! e_tag, +//! data: data.clone(), +//! refreshed_at: Instant::now(), +//! }; +//! self.entries.insert(path.clone(), entry); +//! } +//! data +//! } +//! }) +//! } +//! } +//! ``` +//! +//! # Conditional Put +//! +//! The default behaviour when writing data is to upsert any existing object at the given path, +//! overwriting any previous value. More complex behaviours can be achieved using [`PutMode`], and +//! can be used to build [Optimistic Concurrency Control] based transactions. This facilitates +//! building metadata catalogs, such as [Apache Iceberg] or [Delta Lake], directly on top of object +//! storage, without relying on a separate DBMS. +//! +//! ``` +//! # use object_store::{Error, ObjectStore, PutMode, UpdateVersion}; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::memory::InMemory; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(InMemory::new()) +//! # } +//! # fn do_update(b: Bytes) -> Bytes {b} +//! # async fn conditional_put() { +//! let store = get_object_store(); +//! let path = Path::from("test"); +//! +//! // Perform a conditional update on path +//! loop { +//! // Perform get request +//! let r = store.get(&path).await.unwrap(); +//! +//! // Save version information fetched +//! let version = UpdateVersion { +//! e_tag: r.meta.e_tag.clone(), +//! version: r.meta.version.clone(), +//! }; +//! +//! // Compute new version of object contents +//! let new = do_update(r.bytes().await.unwrap()); +//! +//! // Attempt to commit transaction +//! match store.put_opts(&path, new.into(), PutMode::Update(version).into()).await { +//! Ok(_) => break, // Successfully committed +//! Err(Error::Precondition { .. }) => continue, // Object has changed, try again +//! Err(e) => panic!("{e}") +//! } +//! } +//! # } +//! ``` +//! +//! [Optimistic Concurrency Control]: https://en.wikipedia.org/wiki/Optimistic_concurrency_control +//! [Apache Iceberg]: https://iceberg.apache.org/ +//! [Delta Lake]: https://delta.io/ +//! +//! # TLS Certificates +//! +//! Stores that use HTTPS/TLS (this is true for most cloud stores) can choose the source of their [CA] +//! certificates. By default the system-bundled certificates are used (see +//! [`rustls-native-certs`]). The `tls-webpki-roots` feature switch can be used to also bundle Mozilla's +//! root certificates with the library/application (see [`webpki-roots`]). +//! +//! [CA]: https://en.wikipedia.org/wiki/Certificate_authority +//! [`rustls-native-certs`]: https://crates.io/crates/rustls-native-certs/ +//! [`webpki-roots`]: https://crates.io/crates/webpki-roots +//! 
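+//!
+//! # Composing Adapters
+//!
+//! The adapters listed above wrap any other [`ObjectStore`] implementation and can
+//! therefore be layered freely. A minimal sketch, using the always-available in-memory
+//! store and an arbitrary limit of 10 concurrent requests chosen purely for illustration:
+//!
+//! ```
+//! # use std::sync::Arc;
+//! # use object_store::ObjectStore;
+//! # use object_store::limit::LimitStore;
+//! # use object_store::memory::InMemory;
+//! // Cap the wrapped store at 10 concurrent operations
+//! let store: Arc<dyn ObjectStore> = Arc::new(LimitStore::new(InMemory::new(), 10));
+//! ```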
+ +#[cfg(feature = "aws")] +pub mod aws; +#[cfg(feature = "azure")] +pub mod azure; +pub mod buffered; +#[cfg(not(target_arch = "wasm32"))] +pub mod chunked; +pub mod delimited; +#[cfg(feature = "gcp")] +pub mod gcp; +#[cfg(feature = "http")] +pub mod http; +pub mod limit; +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] +pub mod local; +pub mod memory; +pub mod path; +pub mod prefix; +#[cfg(feature = "cloud")] +pub mod signer; +pub mod throttle; + +#[cfg(feature = "cloud")] +pub mod client; + +#[cfg(feature = "cloud")] +pub use client::{ + backoff::BackoffConfig, retry::RetryConfig, ClientConfigKey, ClientOptions, CredentialProvider, + StaticCredentialProvider, +}; + +#[cfg(all(feature = "cloud", not(target_arch = "wasm32")))] +pub use client::Certificate; + +#[cfg(feature = "cloud")] +mod config; + +mod tags; + +pub use tags::TagSet; + +pub mod multipart; +mod parse; +mod payload; +mod upload; +mod util; + +mod attributes; + +#[cfg(any(feature = "integration", test))] +pub mod integration; + +pub use attributes::*; + +pub use parse::{parse_url, parse_url_opts, ObjectStoreScheme}; +pub use payload::*; +pub use upload::*; +pub use util::{coalesce_ranges, collect_bytes, GetRange, OBJECT_STORE_COALESCE_DEFAULT}; + +use crate::path::Path; +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] +use crate::util::maybe_spawn_blocking; +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use std::fmt::{Debug, Formatter}; +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] +use std::io::{Read, Seek, SeekFrom}; +use std::ops::Range; +use std::sync::Arc; + +/// An alias for a dynamically dispatched object store implementation. +pub type DynObjectStore = dyn ObjectStore; + +/// Id type for multipart uploads. +pub type MultipartId = String; + +/// Universal API to multiple object store services. +#[async_trait] +pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { + /// Save the provided bytes to the specified location + /// + /// The operation is guaranteed to be atomic, it will either successfully + /// write the entirety of `payload` to `location`, or fail. No clients + /// should be able to observe a partially written object + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + self.put_opts(location, payload, PutOptions::default()) + .await + } + + /// Save the provided `payload` to `location` with the given options + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result; + + /// Perform a multipart upload + /// + /// Client should prefer [`ObjectStore::put`] for small payloads, as streaming uploads + /// typically require multiple separate requests. See [`MultipartUpload`] for more information + async fn put_multipart(&self, location: &Path) -> Result> { + self.put_multipart_opts(location, PutMultipartOpts::default()) + .await + } + + /// Perform a multipart upload with options + /// + /// Client should prefer [`ObjectStore::put`] for small payloads, as streaming uploads + /// typically require multiple separate requests. See [`MultipartUpload`] for more information + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result>; + + /// Return the bytes that are stored at the specified location. 
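+    ///
+    /// A minimal sketch using the always-available in-memory store; any other
+    /// [`ObjectStore`] implementation is used in the same way:
+    ///
+    /// ```
+    /// # use object_store::{memory::InMemory, path::Path, ObjectStore};
+    /// # async fn example() -> object_store::Result<()> {
+    /// let store = InMemory::new();
+    /// let path = Path::from("data/file");
+    /// store.put(&path, "hello".into()).await?;
+    /// let bytes = store.get(&path).await?.bytes().await?;
+    /// assert_eq!(bytes.as_ref(), b"hello");
+    /// # Ok(())
+    /// # }
+    /// ```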
+ async fn get(&self, location: &Path) -> Result { + self.get_opts(location, GetOptions::default()).await + } + + /// Perform a get request with options + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result; + + /// Return the bytes that are stored at the specified location + /// in the given byte range. + /// + /// See [`GetRange::Bounded`] for more details on how `range` gets interpreted + async fn get_range(&self, location: &Path, range: Range) -> Result { + let options = GetOptions { + range: Some(range.into()), + ..Default::default() + }; + self.get_opts(location, options).await?.bytes().await + } + + /// Return the bytes that are stored at the specified location + /// in the given byte ranges + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + coalesce_ranges( + ranges, + |range| self.get_range(location, range), + OBJECT_STORE_COALESCE_DEFAULT, + ) + .await + } + + /// Return the metadata for the specified location + async fn head(&self, location: &Path) -> Result { + let options = GetOptions { + head: true, + ..Default::default() + }; + Ok(self.get_opts(location, options).await?.meta) + } + + /// Delete the object at the specified location. + async fn delete(&self, location: &Path) -> Result<()>; + + /// Delete all the objects at the specified locations + /// + /// When supported, this method will use bulk operations that delete more + /// than one object per a request. The default implementation will call + /// the single object delete method for each location, but with up to 10 + /// concurrent requests. + /// + /// The returned stream yields the results of the delete operations in the + /// same order as the input locations. However, some errors will be from + /// an overall call to a bulk delete operation, and not from a specific + /// location. + /// + /// If the object did not exist, the result may be an error or a success, + /// depending on the behavior of the underlying store. For example, local + /// filesystems, GCP, and Azure return an error, while S3 and in-memory will + /// return Ok. If it is an error, it will be [`Error::NotFound`]. + /// + /// ``` + /// # use futures::{StreamExt, TryStreamExt}; + /// # use object_store::local::LocalFileSystem; + /// # async fn example() -> Result<(), Box> { + /// # let root = tempfile::TempDir::new().unwrap(); + /// # let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + /// # use object_store::{ObjectStore, ObjectMeta}; + /// # use object_store::path::Path; + /// # use futures::{StreamExt, TryStreamExt}; + /// # + /// // Create two objects + /// store.put(&Path::from("foo"), "foo".into()).await?; + /// store.put(&Path::from("bar"), "bar".into()).await?; + /// + /// // List object + /// let locations = store.list(None).map_ok(|m| m.location).boxed(); + /// + /// // Delete them + /// store.delete_stream(locations).try_collect::>().await?; + /// # Ok(()) + /// # } + /// # let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); + /// # rt.block_on(example()).unwrap(); + /// ``` + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + locations + .map(|location| async { + let location = location?; + self.delete(&location).await?; + Ok(location) + }) + .buffered(10) + .boxed() + } + + /// List all the objects with the given prefix. + /// + /// Prefixes are evaluated on a path segment basis, i.e. `foo/bar` is a prefix of `foo/bar/x` but not of + /// `foo/bar_baz/x`. List is recursive, i.e. 
`foo/bar/more/x` will be included. + /// + /// Note: the order of returned [`ObjectMeta`] is not guaranteed + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result>; + + /// List all the objects with the given prefix and a location greater than `offset` + /// + /// Some stores, such as S3 and GCS, may be able to push `offset` down to reduce + /// the number of network requests required + /// + /// Note: the order of returned [`ObjectMeta`] is not guaranteed + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + let offset = offset.clone(); + self.list(prefix) + .try_filter(move |f| futures::future::ready(f.location > offset)) + .boxed() + } + + /// List objects with the given prefix and an implementation specific + /// delimiter. Returns common prefixes (directories) in addition to object + /// metadata. + /// + /// Prefixes are evaluated on a path segment basis, i.e. `foo/bar` is a prefix of `foo/bar/x` but not of + /// `foo/bar_baz/x`. List is not recursive, i.e. `foo/bar/more/x` will not be included. + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result; + + /// Copy an object from one path to another in the same object store. + /// + /// If there exists an object at the destination, it will be overwritten. + async fn copy(&self, from: &Path, to: &Path) -> Result<()>; + + /// Move an object from one path to another in the same object store. + /// + /// By default, this is implemented as a copy and then delete source. It may not + /// check when deleting source that it was the same object that was originally copied. + /// + /// If there exists an object at the destination, it will be overwritten. + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + self.copy(from, to).await?; + self.delete(from).await + } + + /// Copy an object from one path to another, only if destination is empty. + /// + /// Will return an error if the destination already has an object. + /// + /// Performs an atomic operation if the underlying object storage supports it. + /// If atomic operations are not supported by the underlying object storage (like S3) + /// it will return an error. + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()>; + + /// Move an object from one path to another in the same object store. + /// + /// Will return an error if the destination already has an object. + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.copy_if_not_exists(from, to).await?; + self.delete(from).await + } +} + +macro_rules! 
as_ref_impl { + ($type:ty) => { + #[async_trait] + impl ObjectStore for $type { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + self.as_ref().put(location, payload).await + } + + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.as_ref().put_opts(location, payload, opts).await + } + + async fn put_multipart(&self, location: &Path) -> Result> { + self.as_ref().put_multipart(location).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + self.as_ref().put_multipart_opts(location, opts).await + } + + async fn get(&self, location: &Path) -> Result { + self.as_ref().get(location).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.as_ref().get_opts(location, options).await + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + self.as_ref().get_range(location, range).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + self.as_ref().get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + self.as_ref().head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.as_ref().delete(location).await + } + + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + self.as_ref().delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.as_ref().list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + self.as_ref().list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.as_ref().list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().rename(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().copy_if_not_exists(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().rename_if_not_exists(from, to).await + } + } + }; +} + +as_ref_impl!(Arc); +as_ref_impl!(Box); + +/// Result of a list call that includes objects, prefixes (directories) and a +/// token for the next set of results. Individual result sets may be limited to +/// 1,000 objects based on the underlying object storage's limitations. +#[derive(Debug)] +pub struct ListResult { + /// Prefixes that are common (like directories) + pub common_prefixes: Vec, + /// Object metadata for the listing + pub objects: Vec, +} + +/// The metadata that describes an object. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectMeta { + /// The full path to the object + pub location: Path, + /// The last modified time + pub last_modified: DateTime, + /// The size in bytes of the object. 
+    ///
+    /// Note this is not `usize` as `object_store` supports 32-bit architectures such as WASM
+    pub size: u64,
+    /// The unique identifier for the object
+    ///
+    ///
+    pub e_tag: Option<String>,
+    /// A version indicator for this object
+    pub version: Option<String>,
+}
+
+/// Options for a get request, such as range
+#[derive(Debug, Default, Clone)]
+pub struct GetOptions {
+    /// Request will succeed if the `ObjectMeta::e_tag` matches
+    /// otherwise returning [`Error::Precondition`]
+    ///
+    /// See
+    ///
+    /// Examples:
+    ///
+    /// ```text
+    /// If-Match: "xyzzy"
+    /// If-Match: "xyzzy", "r2d2xxxx", "c3piozzzz"
+    /// If-Match: *
+    /// ```
+    pub if_match: Option<String>,
+    /// Request will succeed if the `ObjectMeta::e_tag` does not match
+    /// otherwise returning [`Error::NotModified`]
+    ///
+    /// See
+    ///
+    /// Examples:
+    ///
+    /// ```text
+    /// If-None-Match: "xyzzy"
+    /// If-None-Match: "xyzzy", "r2d2xxxx", "c3piozzzz"
+    /// If-None-Match: *
+    /// ```
+    pub if_none_match: Option<String>,
+    /// Request will succeed if the object has been modified since
+    ///
+    ///
+    pub if_modified_since: Option<DateTime<Utc>>,
+    /// Request will succeed if the object has not been modified since
+    /// otherwise returning [`Error::Precondition`]
+    ///
+    /// Some stores, such as S3, will only return `NotModified` for exact
+    /// timestamp matches, instead of for any timestamp greater than or equal.
+    ///
+    ///
+    pub if_unmodified_since: Option<DateTime<Utc>>,
+    /// Request transfer of only the specified range of bytes
+    ///
+    ///
+    pub range: Option<GetRange>,
+    /// Request a particular object version
+    pub version: Option<String>,
+    /// Request transfer of no content
+    ///
+    ///
+    pub head: bool,
+    /// Implementation-specific extensions. Intended for use by [`ObjectStore`] implementations
+    /// that need to pass context-specific information (like tracing spans) via trait methods.
+    ///
+    /// These extensions are ignored entirely by backends offered through this crate.
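+    ///
+    /// A minimal sketch of attaching such an extension; the `Trace` type below is purely
+    /// illustrative and not part of this crate:
+    ///
+    /// ```
+    /// # use object_store::GetOptions;
+    /// #[derive(Clone)]
+    /// struct Trace(&'static str);
+    ///
+    /// let mut options = GetOptions::default();
+    /// options.extensions.insert(Trace("trace-id-1234"));
+    /// assert_eq!(options.extensions.get::<Trace>().unwrap().0, "trace-id-1234");
+    /// ```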
+ pub extensions: ::http::Extensions, +} + +impl GetOptions { + /// Returns an error if the modification conditions on this request are not satisfied + /// + /// + fn check_preconditions(&self, meta: &ObjectMeta) -> Result<()> { + // The use of the invalid etag "*" means no ETag is equivalent to never matching + let etag = meta.e_tag.as_deref().unwrap_or("*"); + let last_modified = meta.last_modified; + + if let Some(m) = &self.if_match { + if m != "*" && m.split(',').map(str::trim).all(|x| x != etag) { + return Err(Error::Precondition { + path: meta.location.to_string(), + source: format!("{etag} does not match {m}").into(), + }); + } + } else if let Some(date) = self.if_unmodified_since { + if last_modified > date { + return Err(Error::Precondition { + path: meta.location.to_string(), + source: format!("{date} < {last_modified}").into(), + }); + } + } + + if let Some(m) = &self.if_none_match { + if m == "*" || m.split(',').map(str::trim).any(|x| x == etag) { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{etag} matches {m}").into(), + }); + } + } else if let Some(date) = self.if_modified_since { + if last_modified <= date { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{date} >= {last_modified}").into(), + }); + } + } + Ok(()) + } +} + +/// Result for a get request +#[derive(Debug)] +pub struct GetResult { + /// The [`GetResultPayload`] + pub payload: GetResultPayload, + /// The [`ObjectMeta`] for this object + pub meta: ObjectMeta, + /// The range of bytes returned by this request + /// + /// Note this is not `usize` as `object_store` supports 32-bit architectures such as WASM + pub range: Range, + /// Additional object attributes + pub attributes: Attributes, +} + +/// The kind of a [`GetResult`] +/// +/// This special cases the case of a local file, as some systems may +/// be able to optimise the case of a file already present on local disk +pub enum GetResultPayload { + /// The file, path + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + File(std::fs::File, std::path::PathBuf), + /// An opaque stream of bytes + Stream(BoxStream<'static, Result>), +} + +impl Debug for GetResultPayload { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + Self::File(_, _) => write!(f, "GetResultPayload(File)"), + Self::Stream(_) => write!(f, "GetResultPayload(Stream)"), + } + } +} + +impl GetResult { + /// Collects the data into a [`Bytes`] + pub async fn bytes(self) -> Result { + let len = self.range.end - self.range.start; + match self.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + GetResultPayload::File(mut file, path) => { + maybe_spawn_blocking(move || { + file.seek(SeekFrom::Start(self.range.start as _)) + .map_err(|source| local::Error::Seek { + source, + path: path.clone(), + })?; + + let mut buffer = if let Ok(len) = len.try_into() { + Vec::with_capacity(len) + } else { + Vec::new() + }; + file.take(len as _) + .read_to_end(&mut buffer) + .map_err(|source| local::Error::UnableToReadBytes { source, path })?; + + Ok(buffer.into()) + }) + .await + } + GetResultPayload::Stream(s) => collect_bytes(s, Some(len)).await, + } + } + + /// Converts this into a byte stream + /// + /// If the `self.kind` is [`GetResultPayload::File`] will perform chunked reads of the file, + /// otherwise will return the [`GetResultPayload::Stream`]. 
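+    ///
+    /// A minimal sketch of consuming the resulting stream chunk by chunk; the in-memory
+    /// store is used purely for illustration:
+    ///
+    /// ```
+    /// # use futures::TryStreamExt;
+    /// # use object_store::{memory::InMemory, path::Path, ObjectStore};
+    /// # async fn example() -> object_store::Result<()> {
+    /// # let store = InMemory::new();
+    /// # let path = Path::from("data/file");
+    /// # store.put(&path, "hello".into()).await?;
+    /// let mut stream = store.get(&path).await?.into_stream();
+    /// while let Some(chunk) = stream.try_next().await? {
+    ///     println!("read {} bytes", chunk.len());
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```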
+ /// + /// # Tokio Compatibility + /// + /// Tokio discourages performing blocking IO on a tokio worker thread, however, + /// no major operating systems have stable async file APIs. Therefore if called from + /// a tokio context, this will use [`tokio::runtime::Handle::spawn_blocking`] to dispatch + /// IO to a blocking thread pool, much like `tokio::fs` does under-the-hood. + /// + /// If not called from a tokio context, this will perform IO on the current thread with + /// no additional complexity or overheads + pub fn into_stream(self) -> BoxStream<'static, Result> { + match self.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + GetResultPayload::File(file, path) => { + const CHUNK_SIZE: usize = 8 * 1024; + local::chunked_stream(file, path, self.range, CHUNK_SIZE) + } + GetResultPayload::Stream(s) => s, + } + } +} + +/// Configure preconditions for the put operation +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum PutMode { + /// Perform an atomic write operation, overwriting any object present at the provided path + #[default] + Overwrite, + /// Perform an atomic write operation, returning [`Error::AlreadyExists`] if an + /// object already exists at the provided path + Create, + /// Perform an atomic write operation if the current version of the object matches the + /// provided [`UpdateVersion`], returning [`Error::Precondition`] otherwise + Update(UpdateVersion), +} + +/// Uniquely identifies a version of an object to update +/// +/// Stores will use differing combinations of `e_tag` and `version` to provide conditional +/// updates, and it is therefore recommended applications preserve both +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpdateVersion { + /// The unique identifier for the newly created object + /// + /// + pub e_tag: Option, + /// A version indicator for the newly created object + pub version: Option, +} + +impl From for UpdateVersion { + fn from(value: PutResult) -> Self { + Self { + e_tag: value.e_tag, + version: value.version, + } + } +} + +/// Options for a put request +#[derive(Debug, Clone, Default)] +pub struct PutOptions { + /// Configure the [`PutMode`] for this operation + pub mode: PutMode, + /// Provide a [`TagSet`] for this object + /// + /// Implementations that don't support object tagging should ignore this + pub tags: TagSet, + /// Provide a set of [`Attributes`] + /// + /// Implementations that don't support an attribute should return an error + pub attributes: Attributes, + /// Implementation-specific extensions. Intended for use by [`ObjectStore`] implementations + /// that need to pass context-specific information (like tracing spans) via trait methods. + /// + /// These extensions are ignored entirely by backends offered through this crate. + /// + /// They are also eclused from [`PartialEq`] and [`Eq`]. 
+ pub extensions: ::http::Extensions, +} + +impl PartialEq for PutOptions { + fn eq(&self, other: &Self) -> bool { + let Self { + mode, + tags, + attributes, + extensions: _, + } = self; + let Self { + mode: other_mode, + tags: other_tags, + attributes: other_attributes, + extensions: _, + } = other; + (mode == other_mode) && (tags == other_tags) && (attributes == other_attributes) + } +} + +impl Eq for PutOptions {} + +impl From for PutOptions { + fn from(mode: PutMode) -> Self { + Self { + mode, + ..Default::default() + } + } +} + +impl From for PutOptions { + fn from(tags: TagSet) -> Self { + Self { + tags, + ..Default::default() + } + } +} + +impl From for PutOptions { + fn from(attributes: Attributes) -> Self { + Self { + attributes, + ..Default::default() + } + } +} + +/// Options for [`ObjectStore::put_multipart_opts`] +#[derive(Debug, Clone, Default)] +pub struct PutMultipartOpts { + /// Provide a [`TagSet`] for this object + /// + /// Implementations that don't support object tagging should ignore this + pub tags: TagSet, + /// Provide a set of [`Attributes`] + /// + /// Implementations that don't support an attribute should return an error + pub attributes: Attributes, + /// Implementation-specific extensions. Intended for use by [`ObjectStore`] implementations + /// that need to pass context-specific information (like tracing spans) via trait methods. + /// + /// These extensions are ignored entirely by backends offered through this crate. + /// + /// They are also eclused from [`PartialEq`] and [`Eq`]. + pub extensions: ::http::Extensions, +} + +impl PartialEq for PutMultipartOpts { + fn eq(&self, other: &Self) -> bool { + let Self { + tags, + attributes, + extensions: _, + } = self; + let Self { + tags: other_tags, + attributes: other_attributes, + extensions: _, + } = other; + (tags == other_tags) && (attributes == other_attributes) + } +} + +impl Eq for PutMultipartOpts {} + +impl From for PutMultipartOpts { + fn from(tags: TagSet) -> Self { + Self { + tags, + ..Default::default() + } + } +} + +impl From for PutMultipartOpts { + fn from(attributes: Attributes) -> Self { + Self { + attributes, + ..Default::default() + } + } +} + +/// Result for a put request +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PutResult { + /// The unique identifier for the newly created object + /// + /// + pub e_tag: Option, + /// A version indicator for the newly created object + pub version: Option, +} + +/// A specialized `Result` for object store-related errors +pub type Result = std::result::Result; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error { + /// A fallback error type when no variant matches + #[error("Generic {} error: {}", store, source)] + Generic { + /// The store this error originated from + store: &'static str, + /// The wrapped error + source: Box, + }, + + /// Error when the object is not found at given location + #[error("Object at location {} not found: {}", path, source)] + NotFound { + /// The path to file + path: String, + /// The wrapped error + source: Box, + }, + + /// Error for invalid path + #[error("Encountered object with invalid path: {}", source)] + InvalidPath { + /// The wrapped error + #[from] + source: path::Error, + }, + + /// Error when `tokio::spawn` failed + #[error("Error joining spawned task: {}", source)] + JoinError { + /// The wrapped error + #[from] + source: tokio::task::JoinError, + }, + + /// Error when the attempted operation is not supported + 
#[error("Operation not supported: {}", source)] + NotSupported { + /// The wrapped error + source: Box, + }, + + /// Error when the object already exists + #[error("Object at location {} already exists: {}", path, source)] + AlreadyExists { + /// The path to the + path: String, + /// The wrapped error + source: Box, + }, + + /// Error when the required conditions failed for the operation + #[error("Request precondition failure for path {}: {}", path, source)] + Precondition { + /// The path to the file + path: String, + /// The wrapped error + source: Box, + }, + + /// Error when the object at the location isn't modified + #[error("Object at location {} not modified: {}", path, source)] + NotModified { + /// The path to the file + path: String, + /// The wrapped error + source: Box, + }, + + /// Error when an operation is not implemented + #[error("Operation not yet implemented.")] + NotImplemented, + + /// Error when the used credentials don't have enough permission + /// to perform the requested operation + #[error( + "The operation lacked the necessary privileges to complete for path {}: {}", + path, + source + )] + PermissionDenied { + /// The path to the file + path: String, + /// The wrapped error + source: Box, + }, + + /// Error when the used credentials lack valid authentication + #[error( + "The operation lacked valid authentication credentials for path {}: {}", + path, + source + )] + Unauthenticated { + /// The path to the file + path: String, + /// The wrapped error + source: Box, + }, + + /// Error when a configuration key is invalid for the store used + #[error("Configuration key: '{}' is not valid for store '{}'.", key, store)] + UnknownConfigurationKey { + /// The object store used + store: &'static str, + /// The configuration key used + key: String, + }, +} + +impl From for std::io::Error { + fn from(e: Error) -> Self { + let kind = match &e { + Error::NotFound { .. } => std::io::ErrorKind::NotFound, + _ => std::io::ErrorKind::Other, + }; + Self::new(kind, e) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::buffered::BufWriter; + use chrono::TimeZone; + use tokio::io::AsyncWriteExt; + + macro_rules! 
maybe_skip_integration { + () => { + if std::env::var("TEST_INTEGRATION").is_err() { + eprintln!("Skipping integration test - set TEST_INTEGRATION"); + return; + } + }; + } + pub(crate) use maybe_skip_integration; + + /// Test that the returned stream does not borrow the lifetime of Path + fn list_store<'a>( + store: &'a dyn ObjectStore, + path_str: &str, + ) -> BoxStream<'a, Result> { + let path = Path::from(path_str); + store.list(Some(&path)) + } + + #[cfg(any(feature = "azure", feature = "aws"))] + pub(crate) async fn signing(integration: &T) + where + T: ObjectStore + signer::Signer, + { + use reqwest::Method; + use std::time::Duration; + + let data = Bytes::from("hello world"); + let path = Path::from("file.txt"); + integration.put(&path, data.clone().into()).await.unwrap(); + + let signed = integration + .signed_url(Method::GET, &path, Duration::from_secs(60)) + .await + .unwrap(); + + let resp = reqwest::get(signed).await.unwrap(); + let loaded = resp.bytes().await.unwrap(); + + assert_eq!(data, loaded); + } + + #[cfg(any(feature = "aws", feature = "azure"))] + pub(crate) async fn tagging(storage: Arc, validate: bool, get_tags: F) + where + F: Fn(Path) -> Fut + Send + Sync, + Fut: std::future::Future> + Send, + { + use bytes::Buf; + use serde::Deserialize; + + #[derive(Deserialize)] + struct Tagging { + #[serde(rename = "TagSet")] + list: TagList, + } + + #[derive(Debug, Deserialize)] + struct TagList { + #[serde(rename = "Tag")] + tags: Vec, + } + + #[derive(Debug, Deserialize, Eq, PartialEq)] + #[serde(rename_all = "PascalCase")] + struct Tag { + key: String, + value: String, + } + + let tags = vec![ + Tag { + key: "foo.com=bar/s".to_string(), + value: "bananas/foo.com-_".to_string(), + }, + Tag { + key: "namespace/key.foo".to_string(), + value: "value with a space".to_string(), + }, + ]; + let mut tag_set = TagSet::default(); + for t in &tags { + tag_set.push(&t.key, &t.value) + } + + let path = Path::from("tag_test"); + storage + .put_opts(&path, "test".into(), tag_set.clone().into()) + .await + .unwrap(); + + let multi_path = Path::from("tag_test_multi"); + let mut write = storage + .put_multipart_opts(&multi_path, tag_set.clone().into()) + .await + .unwrap(); + + write.put_part("foo".into()).await.unwrap(); + write.complete().await.unwrap(); + + let buf_path = Path::from("tag_test_buf"); + let mut buf = BufWriter::new(storage, buf_path.clone()).with_tags(tag_set); + buf.write_all(b"foo").await.unwrap(); + buf.shutdown().await.unwrap(); + + // Write should always succeed, but certain configurations may simply ignore tags + if !validate { + return; + } + + for path in [path, multi_path, buf_path] { + let resp = get_tags(path.clone()).await.unwrap(); + let body = resp.into_body().bytes().await.unwrap(); + + let mut resp: Tagging = quick_xml::de::from_reader(body.reader()).unwrap(); + resp.list.tags.sort_by(|a, b| a.key.cmp(&b.key)); + assert_eq!(resp.list.tags, tags); + } + } + + #[tokio::test] + async fn test_list_lifetimes() { + let store = memory::InMemory::new(); + let mut stream = list_store(&store, "path"); + assert!(stream.next().await.is_none()); + } + + #[test] + fn test_preconditions() { + let mut meta = ObjectMeta { + location: Path::from("test"), + last_modified: Utc.timestamp_nanos(100), + size: 100, + e_tag: Some("123".to_string()), + version: None, + }; + + let mut options = GetOptions::default(); + options.check_preconditions(&meta).unwrap(); + + options.if_modified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap(); + + 
options.if_modified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_modified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("123,354".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354, 123,".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap(); + + // If-Match takes precedence + options.if_unmodified_since = Some(Utc.timestamp_nanos(200)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_none_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("1232".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_none_match = Some("23, 123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + // If-None-Match takes precedence + options.if_modified_since = Some(Utc.timestamp_nanos(10)); + options.check_preconditions(&meta).unwrap_err(); + + // Check missing ETag + meta.e_tag = None; + options = GetOptions::default(); + + options.if_none_match = Some("*".to_string()); // Fails if any file exists + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + options.if_match = Some("*".to_string()); // Passes if file exists + options.check_preconditions(&meta).unwrap(); + } +} diff --git a/src/limit.rs b/src/limit.rs new file mode 100644 index 0000000..330a0da --- /dev/null +++ b/src/limit.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
An object store that limits the maximum concurrency of the wrapped implementation + +use crate::{ + BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, Path, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, StreamExt, + UploadPart, +}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::{FutureExt, Stream}; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// Store wrapper that wraps an inner store and limits the maximum number of concurrent +/// object store operations. Where each call to an [`ObjectStore`] member function is +/// considered a single operation, even if it may result in more than one network call +/// +/// ``` +/// # use object_store::memory::InMemory; +/// # use object_store::limit::LimitStore; +/// +/// // Create an in-memory `ObjectStore` limited to 20 concurrent requests +/// let store = LimitStore::new(InMemory::new(), 20); +/// ``` +/// +#[derive(Debug)] +pub struct LimitStore { + inner: Arc, + max_requests: usize, + semaphore: Arc, +} + +impl LimitStore { + /// Create new limit store that will limit the maximum + /// number of outstanding concurrent requests to + /// `max_requests` + pub fn new(inner: T, max_requests: usize) -> Self { + Self { + inner: Arc::new(inner), + max_requests, + semaphore: Arc::new(Semaphore::new(max_requests)), + } + } +} + +impl std::fmt::Display for LimitStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LimitStore({}, {})", self.max_requests, self.inner) + } +} + +#[async_trait] +impl ObjectStore for LimitStore { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.put(location, payload).await + } + + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.put_opts(location, payload, opts).await + } + async fn put_multipart(&self, location: &Path) -> Result> { + let upload = self.inner.put_multipart(location).await?; + Ok(Box::new(LimitUpload { + semaphore: Arc::clone(&self.semaphore), + upload, + })) + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + let upload = self.inner.put_multipart_opts(location, opts).await?; + Ok(Box::new(LimitUpload { + semaphore: Arc::clone(&self.semaphore), + upload, + })) + } + + async fn get(&self, location: &Path) -> Result { + let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); + let r = self.inner.get(location).await?; + Ok(permit_get_result(r, permit)) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); + let r = self.inner.get_opts(location, options).await?; + Ok(permit_get_result(r, permit)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.get_range(location, range).await + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + 
self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.delete(location).await + } + + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + self.inner.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let prefix = prefix.cloned(); + let inner = Arc::clone(&self.inner); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = inner.list(prefix.as_ref()); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + let prefix = prefix.cloned(); + let offset = offset.clone(); + let inner = Arc::clone(&self.inner); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = inner.list_with_offset(prefix.as_ref(), &offset); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.rename(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.copy_if_not_exists(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.rename_if_not_exists(from, to).await + } +} + +fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult { + let payload = match r.payload { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + v @ GetResultPayload::File(_, _) => v, + GetResultPayload::Stream(s) => { + GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed()) + } + }; + GetResult { payload, ..r } +} + +/// Combines an [`OwnedSemaphorePermit`] with some other type +struct PermitWrapper { + inner: T, + #[allow(dead_code)] + permit: OwnedSemaphorePermit, +} + +impl PermitWrapper { + fn new(inner: T, permit: OwnedSemaphorePermit) -> Self { + Self { inner, permit } + } +} + +impl Stream for PermitWrapper { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// An [`MultipartUpload`] wrapper that limits the maximum number of concurrent requests +#[derive(Debug)] +pub struct LimitUpload { + upload: Box, + semaphore: Arc, +} + +impl LimitUpload { + /// Create a new [`LimitUpload`] limiting `upload` to `max_concurrency` concurrent requests + pub fn new(upload: Box, max_concurrency: usize) -> Self { + Self { + upload, + semaphore: Arc::new(Semaphore::new(max_concurrency)), + } + } +} + +#[async_trait] +impl MultipartUpload for LimitUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let upload = self.upload.put_part(data); + let s = Arc::clone(&self.semaphore); + 
Box::pin(async move { + let _permit = s.acquire().await.unwrap(); + upload.await + }) + } + + async fn complete(&mut self) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.upload.complete().await + } + + async fn abort(&mut self) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.upload.abort().await + } +} + +#[cfg(test)] +mod tests { + use crate::integration::*; + use crate::limit::LimitStore; + use crate::memory::InMemory; + use crate::ObjectStore; + use futures::stream::StreamExt; + use std::pin::Pin; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn limit_test() { + let max_requests = 10; + let memory = InMemory::new(); + let integration = LimitStore::new(memory, max_requests); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + + let mut streams = Vec::with_capacity(max_requests); + for _ in 0..max_requests { + let mut stream = integration.list(None).peekable(); + Pin::new(&mut stream).peek().await; // Ensure semaphore is acquired + streams.push(stream); + } + + let t = Duration::from_millis(20); + + // Expect to not be able to make another request + let fut = integration.list(None).collect::>(); + assert!(timeout(t, fut).await.is_err()); + + // Drop one of the streams + streams.pop(); + + // Can now make another request + integration.list(None).collect::>().await; + } +} diff --git a/src/local.rs b/src/local.rs new file mode 100644 index 0000000..ccf6e34 --- /dev/null +++ b/src/local.rs @@ -0,0 +1,1741 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
An object store implementation for a local filesystem +use std::fs::{metadata, symlink_metadata, File, Metadata, OpenOptions}; +use std::io::{ErrorKind, Read, Seek, SeekFrom, Write}; +use std::ops::Range; +use std::sync::Arc; +use std::time::SystemTime; +use std::{collections::BTreeSet, io}; +use std::{collections::VecDeque, path::PathBuf}; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt}; +use futures::{FutureExt, TryStreamExt}; +use parking_lot::Mutex; +use url::Url; +use walkdir::{DirEntry, WalkDir}; + +use crate::{ + maybe_spawn_blocking, + path::{absolute_path_to_url, Path}, + util::InvalidGetRange, + Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMode, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, UploadPart, +}; + +/// A specialized `Error` for filesystem object store-related errors +#[derive(Debug, thiserror::Error)] +pub(crate) enum Error { + #[error("Unable to walk dir: {}", source)] + UnableToWalkDir { source: walkdir::Error }, + + #[error("Unable to access metadata for {}: {}", path, source)] + Metadata { + source: Box, + path: String, + }, + + #[error("Unable to copy data to file: {}", source)] + UnableToCopyDataToFile { source: io::Error }, + + #[error("Unable to rename file: {}", source)] + UnableToRenameFile { source: io::Error }, + + #[error("Unable to create dir {}: {}", path.display(), source)] + UnableToCreateDir { source: io::Error, path: PathBuf }, + + #[error("Unable to create file {}: {}", path.display(), source)] + UnableToCreateFile { source: io::Error, path: PathBuf }, + + #[error("Unable to delete file {}: {}", path.display(), source)] + UnableToDeleteFile { source: io::Error, path: PathBuf }, + + #[error("Unable to open file {}: {}", path.display(), source)] + UnableToOpenFile { source: io::Error, path: PathBuf }, + + #[error("Unable to read data from file {}: {}", path.display(), source)] + UnableToReadBytes { source: io::Error, path: PathBuf }, + + #[error("Out of range of file {}, expected: {}, actual: {}", path.display(), expected, actual)] + OutOfRange { + path: PathBuf, + expected: u64, + actual: u64, + }, + + #[error("Requested range was invalid")] + InvalidRange { source: InvalidGetRange }, + + #[error("Unable to copy file from {} to {}: {}", from.display(), to.display(), source)] + UnableToCopyFile { + from: PathBuf, + to: PathBuf, + source: io::Error, + }, + + #[error("NotFound")] + NotFound { path: PathBuf, source: io::Error }, + + #[error("Error seeking file {}: {}", path.display(), source)] + Seek { source: io::Error, path: PathBuf }, + + #[error("Unable to convert URL \"{}\" to filesystem path", url)] + InvalidUrl { url: Url }, + + #[error("AlreadyExists")] + AlreadyExists { path: String, source: io::Error }, + + #[error("Unable to canonicalize filesystem root: {}", path.display())] + UnableToCanonicalize { path: PathBuf, source: io::Error }, + + #[error("Filenames containing trailing '/#\\d+/' are not supported: {}", path)] + InvalidPath { path: String }, + + #[error("Upload aborted")] + Aborted, +} + +impl From for super::Error { + fn from(source: Error) -> Self { + match source { + Error::NotFound { path, source } => Self::NotFound { + path: path.to_string_lossy().to_string(), + source: source.into(), + }, + Error::AlreadyExists { path, source } => Self::AlreadyExists { + path, + source: source.into(), + }, + _ => Self::Generic { + store: "LocalFileSystem", + source: Box::new(source), + }, 
+ } + } +} + +/// Local filesystem storage providing an [`ObjectStore`] interface to files on +/// local disk. Can optionally be created with a directory prefix +/// +/// # Path Semantics +/// +/// This implementation follows the [file URI] scheme outlined in [RFC 3986]. In +/// particular paths are delimited by `/` +/// +/// [file URI]: https://en.wikipedia.org/wiki/File_URI_scheme +/// [RFC 3986]: https://www.rfc-editor.org/rfc/rfc3986 +/// +/// # Path Semantics +/// +/// [`LocalFileSystem`] will expose the path semantics of the underlying filesystem, which may +/// have additional restrictions beyond those enforced by [`Path`]. +/// +/// For example: +/// +/// * Windows forbids certain filenames, e.g. `COM0`, +/// * Windows forbids folders with trailing `.` +/// * Windows forbids certain ASCII characters, e.g. `<` or `|` +/// * OS X forbids filenames containing `:` +/// * Leading `-` are discouraged on Unix systems where they may be interpreted as CLI flags +/// * Filesystems may have restrictions on the maximum path or path segment length +/// * Filesystem support for non-ASCII characters is inconsistent +/// +/// Additionally some filesystems, such as NTFS, are case-insensitive, whilst others like +/// FAT don't preserve case at all. Further some filesystems support non-unicode character +/// sequences, such as unpaired UTF-16 surrogates, and [`LocalFileSystem`] will error on +/// encountering such sequences. +/// +/// Finally, filenames matching the regex `/.*#\d+/`, e.g. `foo.parquet#123`, are not supported +/// by [`LocalFileSystem`] as they are used to provide atomic writes. Such files will be ignored +/// for listing operations, and attempting to address such a file will error. +/// +/// # Tokio Compatibility +/// +/// Tokio discourages performing blocking IO on a tokio worker thread, however, +/// no major operating systems have stable async file APIs. Therefore if called from +/// a tokio context, this will use [`tokio::runtime::Handle::spawn_blocking`] to dispatch +/// IO to a blocking thread pool, much like `tokio::fs` does under-the-hood. +/// +/// If not called from a tokio context, this will perform IO on the current thread with +/// no additional complexity or overheads +/// +/// # Symlinks +/// +/// [`LocalFileSystem`] will follow symlinks as normal, however, it is worth noting: +/// +/// * Broken symlinks will be silently ignored by listing operations +/// * No effort is made to prevent breaking symlinks when deleting files +/// * Symlinks that resolve to paths outside the root **will** be followed +/// * Mutating a file through one or more symlinks will mutate the underlying file +/// * Deleting a path that resolves to a symlink will only delete the symlink +/// +/// # Cross-Filesystem Copy +/// +/// [`LocalFileSystem::copy`] is implemented using [`std::fs::hard_link`], and therefore +/// does not support copying across filesystem boundaries. 
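+///
+/// # Example
+///
+/// A minimal sketch of writing and inspecting a file through a store rooted at a prefix;
+/// the `/tmp/object-store` directory is illustrative and must already exist:
+///
+/// ```no_run
+/// # use object_store::local::LocalFileSystem;
+/// # use object_store::{path::Path, ObjectStore};
+/// # async fn example() -> object_store::Result<()> {
+/// let store = LocalFileSystem::new_with_prefix("/tmp/object-store")?;
+/// let path = Path::from("dir/file.txt");
+/// store.put(&path, "hello".into()).await?;
+/// let meta = store.head(&path).await?;
+/// assert_eq!(meta.size, 5);
+/// # Ok(())
+/// # }
+/// ```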
+/// +#[derive(Debug)] +pub struct LocalFileSystem { + config: Arc, + // if you want to delete empty directories when deleting files + automatic_cleanup: bool, +} + +#[derive(Debug)] +struct Config { + root: Url, +} + +impl std::fmt::Display for LocalFileSystem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LocalFileSystem({})", self.config.root) + } +} + +impl Default for LocalFileSystem { + fn default() -> Self { + Self::new() + } +} + +impl LocalFileSystem { + /// Create new filesystem storage with no prefix + pub fn new() -> Self { + Self { + config: Arc::new(Config { + root: Url::parse("file:///").unwrap(), + }), + automatic_cleanup: false, + } + } + + /// Create new filesystem storage with `prefix` applied to all paths + /// + /// Returns an error if the path does not exist + /// + pub fn new_with_prefix(prefix: impl AsRef) -> Result { + let path = std::fs::canonicalize(&prefix).map_err(|source| { + let path = prefix.as_ref().into(); + Error::UnableToCanonicalize { source, path } + })?; + + Ok(Self { + config: Arc::new(Config { + root: absolute_path_to_url(path)?, + }), + automatic_cleanup: false, + }) + } + + /// Return an absolute filesystem path of the given file location + pub fn path_to_filesystem(&self, location: &Path) -> Result { + if !is_valid_file_path(location) { + let path = location.as_ref().into(); + let error = Error::InvalidPath { path }; + return Err(error.into()); + } + + let path = self.config.prefix_to_filesystem(location)?; + + #[cfg(target_os = "windows")] + let path = { + let path = path.to_string_lossy(); + + // Assume the first char is the drive letter and the next is a colon. + let mut out = String::new(); + let drive = &path[..2]; // The drive letter and colon (e.g., "C:") + let filepath = &path[2..].replace(':', "%3A"); // Replace subsequent colons + out.push_str(drive); + out.push_str(filepath); + PathBuf::from(out) + }; + + Ok(path) + } + + /// Enable automatic cleanup of empty directories when deleting files + pub fn with_automatic_cleanup(mut self, automatic_cleanup: bool) -> Self { + self.automatic_cleanup = automatic_cleanup; + self + } +} + +impl Config { + /// Return an absolute filesystem path of the given location + fn prefix_to_filesystem(&self, location: &Path) -> Result { + let mut url = self.root.clone(); + url.path_segments_mut() + .expect("url path") + // technically not necessary as Path ignores empty segments + // but avoids creating paths with "//" which look odd in error messages. + .pop_if_empty() + .extend(location.parts()); + + url.to_file_path() + .map_err(|_| Error::InvalidUrl { url }.into()) + } + + /// Resolves the provided absolute filesystem path to a [`Path`] prefix + fn filesystem_to_path(&self, location: &std::path::Path) -> Result { + Ok(Path::from_absolute_path_with_base( + location, + Some(&self.root), + )?) 
+ } +} + +fn is_valid_file_path(path: &Path) -> bool { + match path.filename() { + Some(p) => match p.split_once('#') { + Some((_, suffix)) if !suffix.is_empty() => { + // Valid if contains non-digits + !suffix.as_bytes().iter().all(|x| x.is_ascii_digit()) + } + _ => true, + }, + None => false, + } +} + +#[async_trait] +impl ObjectStore for LocalFileSystem { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + if matches!(opts.mode, PutMode::Update(_)) { + return Err(crate::Error::NotImplemented); + } + + if !opts.attributes.is_empty() { + return Err(crate::Error::NotImplemented); + } + + let path = self.path_to_filesystem(location)?; + maybe_spawn_blocking(move || { + let (mut file, staging_path) = new_staged_upload(&path)?; + let mut e_tag = None; + + let err = match payload.iter().try_for_each(|x| file.write_all(x)) { + Ok(_) => { + let metadata = file.metadata().map_err(|e| Error::Metadata { + source: e.into(), + path: path.to_string_lossy().to_string(), + })?; + e_tag = Some(get_etag(&metadata)); + match opts.mode { + PutMode::Overwrite => { + // For some fuse types of file systems, the file must be closed first + // to trigger the upload operation, and then renamed, such as Blobfuse + std::mem::drop(file); + match std::fs::rename(&staging_path, &path) { + Ok(_) => None, + Err(source) => Some(Error::UnableToRenameFile { source }), + } + } + PutMode::Create => match std::fs::hard_link(&staging_path, &path) { + Ok(_) => { + let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup + None + } + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => Some(Error::AlreadyExists { + path: path.to_str().unwrap().to_string(), + source, + }), + _ => Some(Error::UnableToRenameFile { source }), + }, + }, + PutMode::Update(_) => unreachable!(), + } + } + Err(source) => Some(Error::UnableToCopyDataToFile { source }), + }; + + if let Some(err) = err { + let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup + return Err(err.into()); + } + + Ok(PutResult { + e_tag, + version: None, + }) + }) + .await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + if !opts.attributes.is_empty() { + return Err(crate::Error::NotImplemented); + } + + let dest = self.path_to_filesystem(location)?; + let (file, src) = new_staged_upload(&dest)?; + Ok(Box::new(LocalUpload::new(src, dest, file))) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let location = location.clone(); + let path = self.path_to_filesystem(&location)?; + maybe_spawn_blocking(move || { + let (file, metadata) = open_file(&path)?; + let meta = convert_metadata(metadata, location); + options.check_preconditions(&meta)?; + + let range = match options.range { + Some(r) => r + .as_range(meta.size) + .map_err(|source| Error::InvalidRange { source })?, + None => 0..meta.size, + }; + + Ok(GetResult { + payload: GetResultPayload::File(file, path), + attributes: Attributes::default(), + range, + meta, + }) + }) + .await + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let path = self.path_to_filesystem(location)?; + maybe_spawn_blocking(move || { + let (mut file, _) = open_file(&path)?; + read_range(&mut file, &path, range) + }) + .await + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + let path = self.path_to_filesystem(location)?; + let ranges = ranges.to_vec(); + maybe_spawn_blocking(move || { + // Vectored IO 
might be faster + let (mut file, _) = open_file(&path)?; + ranges + .into_iter() + .map(|r| read_range(&mut file, &path, r)) + .collect() + }) + .await + } + + async fn delete(&self, location: &Path) -> Result<()> { + let config = Arc::clone(&self.config); + let path = self.path_to_filesystem(location)?; + let automactic_cleanup = self.automatic_cleanup; + maybe_spawn_blocking(move || { + if let Err(e) = std::fs::remove_file(&path) { + Err(match e.kind() { + ErrorKind::NotFound => Error::NotFound { path, source: e }.into(), + _ => Error::UnableToDeleteFile { path, source: e }.into(), + }) + } else if automactic_cleanup { + let root = &config.root; + let root = root + .to_file_path() + .map_err(|_| Error::InvalidUrl { url: root.clone() })?; + + // here we will try to traverse up and delete an empty dir if possible until we reach the root or get an error + let mut parent = path.parent(); + + while let Some(loc) = parent { + if loc != root && std::fs::remove_dir(loc).is_ok() { + parent = loc.parent(); + } else { + break; + } + } + + Ok(()) + } else { + Ok(()) + } + }) + .await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.list_with_maybe_offset(prefix, None) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + self.list_with_maybe_offset(prefix, Some(offset)) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let config = Arc::clone(&self.config); + + let prefix = prefix.cloned().unwrap_or_default(); + let resolved_prefix = config.prefix_to_filesystem(&prefix)?; + + maybe_spawn_blocking(move || { + let walkdir = WalkDir::new(&resolved_prefix) + .min_depth(1) + .max_depth(1) + .follow_links(true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + for entry_res in walkdir.into_iter().map(convert_walkdir_result) { + if let Some(entry) = entry_res? { + let is_directory = entry.file_type().is_dir(); + let entry_location = config.filesystem_to_path(entry.path())?; + if !is_directory && !is_valid_file_path(&entry_location) { + continue; + } + + let mut parts = match entry_location.prefix_match(&prefix) { + Some(parts) => parts, + None => continue, + }; + + let common_prefix = match parts.next() { + Some(p) => p, + None => continue, + }; + + drop(parts); + + if is_directory { + common_prefixes.insert(prefix.child(common_prefix)); + } else if let Some(metadata) = convert_entry(entry, entry_location)? 
{ + objects.push(metadata); + } + } + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + }) + .await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.path_to_filesystem(from)?; + let to = self.path_to_filesystem(to)?; + let mut id = 0; + // In order to make this atomic we: + // + // - hard link to a hidden temporary file + // - atomically rename this temporary file into place + // + // This is necessary because hard_link returns an error if the destination already exists + maybe_spawn_blocking(move || loop { + let staged = staged_upload_path(&to, &id.to_string()); + match std::fs::hard_link(&from, &staged) { + Ok(_) => { + return std::fs::rename(&staged, &to).map_err(|source| { + let _ = std::fs::remove_file(&staged); // Attempt to clean up + Error::UnableToCopyFile { from, to, source }.into() + }); + } + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => id += 1, + ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } + }) + .await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.path_to_filesystem(from)?; + let to = self.path_to_filesystem(to)?; + maybe_spawn_blocking(move || loop { + match std::fs::rename(&from, &to) { + Ok(_) => return Ok(()), + Err(source) => match source.kind() { + ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } + }) + .await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.path_to_filesystem(from)?; + let to = self.path_to_filesystem(to)?; + + maybe_spawn_blocking(move || loop { + match std::fs::hard_link(&from, &to) { + Ok(_) => return Ok(()), + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => { + return Err(Error::AlreadyExists { + path: to.to_str().unwrap().to_string(), + source, + } + .into()) + } + ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } + }) + .await + } +} + +impl LocalFileSystem { + fn list_with_maybe_offset( + &self, + prefix: Option<&Path>, + maybe_offset: Option<&Path>, + ) -> BoxStream<'static, Result> { + let config = Arc::clone(&self.config); + + let root_path = match prefix { + Some(prefix) => match config.prefix_to_filesystem(prefix) { + Ok(path) => path, + Err(e) => return futures::future::ready(Err(e)).into_stream().boxed(), + }, + None => config.root.to_file_path().unwrap(), + }; + + let walkdir = WalkDir::new(root_path) + // Don't include the root directory itself + .min_depth(1) + .follow_links(true); + + let maybe_offset = maybe_offset.cloned(); + + let s = walkdir.into_iter().flat_map(move |result_dir_entry| { + // Apply offset filter before proceeding, to reduce statx file system calls + // This matters for NFS mounts + if let (Some(offset), Ok(entry)) = (maybe_offset.as_ref(), result_dir_entry.as_ref()) { + let location = config.filesystem_to_path(entry.path()); + match location { + Ok(path) if path <= *offset => return None, + Err(e) => return 
Some(Err(e)), + _ => {} + } + } + + let entry = match convert_walkdir_result(result_dir_entry).transpose()? { + Ok(entry) => entry, + Err(e) => return Some(Err(e)), + }; + + if !entry.path().is_file() { + return None; + } + + match config.filesystem_to_path(entry.path()) { + Ok(path) => match is_valid_file_path(&path) { + true => convert_entry(entry, path).transpose(), + false => None, + }, + Err(e) => Some(Err(e)), + } + }); + + // If no tokio context, return iterator directly as no + // need to perform chunked spawn_blocking reads + if tokio::runtime::Handle::try_current().is_err() { + return futures::stream::iter(s).boxed(); + } + + // Otherwise list in batches of CHUNK_SIZE + const CHUNK_SIZE: usize = 1024; + + let buffer = VecDeque::with_capacity(CHUNK_SIZE); + futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { + if buffer.is_empty() { + (s, buffer) = tokio::task::spawn_blocking(move || { + for _ in 0..CHUNK_SIZE { + match s.next() { + Some(r) => buffer.push_back(r), + None => break, + } + } + (s, buffer) + }) + .await?; + } + + match buffer.pop_front() { + Some(Err(e)) => Err(e), + Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), + None => Ok(None), + } + }) + .boxed() + } +} + +/// Creates the parent directories of `path` or returns an error based on `source` if no parent +fn create_parent_dirs(path: &std::path::Path, source: io::Error) -> Result<()> { + let parent = path.parent().ok_or_else(|| { + let path = path.to_path_buf(); + Error::UnableToCreateFile { path, source } + })?; + + std::fs::create_dir_all(parent).map_err(|source| { + let path = parent.into(); + Error::UnableToCreateDir { source, path } + })?; + Ok(()) +} + +/// Generates a unique file path `{base}#{suffix}`, returning the opened `File` and `path` +/// +/// Creates any directories if necessary +fn new_staged_upload(base: &std::path::Path) -> Result<(File, PathBuf)> { + let mut multipart_id = 1; + loop { + let suffix = multipart_id.to_string(); + let path = staged_upload_path(base, &suffix); + let mut options = OpenOptions::new(); + match options.read(true).write(true).create_new(true).open(&path) { + Ok(f) => return Ok((f, path)), + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => multipart_id += 1, + ErrorKind::NotFound => create_parent_dirs(&path, source)?, + _ => return Err(Error::UnableToOpenFile { source, path }.into()), + }, + } + } +} + +/// Returns the unique upload for the given path and suffix +fn staged_upload_path(dest: &std::path::Path, suffix: &str) -> PathBuf { + let mut staging_path = dest.as_os_str().to_owned(); + staging_path.push("#"); + staging_path.push(suffix); + staging_path.into() +} + +#[derive(Debug)] +struct LocalUpload { + /// The upload state + state: Arc, + /// The location of the temporary file + src: Option, + /// The next offset to write into the file + offset: u64, +} + +#[derive(Debug)] +struct UploadState { + dest: PathBuf, + file: Mutex, +} + +impl LocalUpload { + pub(crate) fn new(src: PathBuf, dest: PathBuf, file: File) -> Self { + Self { + state: Arc::new(UploadState { + dest, + file: Mutex::new(file), + }), + src: Some(src), + offset: 0, + } + } +} + +#[async_trait] +impl MultipartUpload for LocalUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let offset = self.offset; + self.offset += data.content_length() as u64; + + let s = Arc::clone(&self.state); + maybe_spawn_blocking(move || { + let mut file = s.file.lock(); + file.seek(SeekFrom::Start(offset)).map_err(|source| { + let path = s.dest.clone(); + 
Error::Seek { source, path } + })?; + + data.iter() + .try_for_each(|x| file.write_all(x)) + .map_err(|source| Error::UnableToCopyDataToFile { source })?; + + Ok(()) + }) + .boxed() + } + + async fn complete(&mut self) -> Result { + let src = self.src.take().ok_or(Error::Aborted)?; + let s = Arc::clone(&self.state); + maybe_spawn_blocking(move || { + // Ensure no inflight writes + let file = s.file.lock(); + std::fs::rename(&src, &s.dest) + .map_err(|source| Error::UnableToRenameFile { source })?; + let metadata = file.metadata().map_err(|e| Error::Metadata { + source: e.into(), + path: src.to_string_lossy().to_string(), + })?; + + Ok(PutResult { + e_tag: Some(get_etag(&metadata)), + version: None, + }) + }) + .await + } + + async fn abort(&mut self) -> Result<()> { + let src = self.src.take().ok_or(Error::Aborted)?; + maybe_spawn_blocking(move || { + std::fs::remove_file(&src) + .map_err(|source| Error::UnableToDeleteFile { source, path: src })?; + Ok(()) + }) + .await + } +} + +impl Drop for LocalUpload { + fn drop(&mut self) { + if let Some(src) = self.src.take() { + // Try to clean up intermediate file ignoring any error + match tokio::runtime::Handle::try_current() { + Ok(r) => drop(r.spawn_blocking(move || std::fs::remove_file(src))), + Err(_) => drop(std::fs::remove_file(src)), + }; + } + } +} + +pub(crate) fn chunked_stream( + mut file: File, + path: PathBuf, + range: Range, + chunk_size: usize, +) -> BoxStream<'static, Result> { + futures::stream::once(async move { + let (file, path) = maybe_spawn_blocking(move || { + file.seek(SeekFrom::Start(range.start as _)) + .map_err(|source| Error::Seek { + source, + path: path.clone(), + })?; + Ok((file, path)) + }) + .await?; + + let stream = futures::stream::try_unfold( + (file, path, range.end - range.start), + move |(mut file, path, remaining)| { + maybe_spawn_blocking(move || { + if remaining == 0 { + return Ok(None); + } + + let to_read = remaining.min(chunk_size as u64); + let cap = usize::try_from(to_read).map_err(|_e| Error::InvalidRange { + source: InvalidGetRange::TooLarge { + requested: to_read, + max: usize::MAX as u64, + }, + })?; + let mut buffer = Vec::with_capacity(cap); + let read = (&mut file) + .take(to_read) + .read_to_end(&mut buffer) + .map_err(|e| Error::UnableToReadBytes { + source: e, + path: path.clone(), + })?; + + Ok(Some((buffer.into(), (file, path, remaining - read as u64)))) + }) + }, + ); + Ok::<_, super::Error>(stream) + }) + .try_flatten() + .boxed() +} + +pub(crate) fn read_range(file: &mut File, path: &PathBuf, range: Range) -> Result { + let file_metadata = file.metadata().map_err(|e| Error::Metadata { + source: e.into(), + path: path.to_string_lossy().to_string(), + })?; + + // If none of the range is satisfiable we should error, e.g. if the start offset is beyond the + // extents of the file + let file_len = file_metadata.len(); + if range.start >= file_len { + return Err(Error::InvalidRange { + source: InvalidGetRange::StartTooLarge { + requested: range.start, + length: file_len, + }, + } + .into()); + } + + // Don't read past end of file + let to_read = range.end.min(file_len) - range.start; + + file.seek(SeekFrom::Start(range.start)).map_err(|source| { + let path = path.into(); + Error::Seek { source, path } + })?; + + let mut buf = Vec::with_capacity(to_read as usize); + let read = file.take(to_read).read_to_end(&mut buf).map_err(|source| { + let path = path.into(); + Error::UnableToReadBytes { source, path } + })? 
as u64; + + if read != to_read { + let error = Error::OutOfRange { + path: path.into(), + expected: to_read, + actual: read, + }; + + return Err(error.into()); + } + + Ok(buf.into()) +} + +fn open_file(path: &PathBuf) -> Result<(File, Metadata)> { + let ret = match File::open(path).and_then(|f| Ok((f.metadata()?, f))) { + Err(e) => Err(match e.kind() { + ErrorKind::NotFound => Error::NotFound { + path: path.clone(), + source: e, + }, + _ => Error::UnableToOpenFile { + path: path.clone(), + source: e, + }, + }), + Ok((metadata, file)) => match !metadata.is_dir() { + true => Ok((file, metadata)), + false => Err(Error::NotFound { + path: path.clone(), + source: io::Error::new(ErrorKind::NotFound, "is directory"), + }), + }, + }?; + Ok(ret) +} + +fn convert_entry(entry: DirEntry, location: Path) -> Result> { + match entry.metadata() { + Ok(metadata) => Ok(Some(convert_metadata(metadata, location))), + Err(e) => { + if let Some(io_err) = e.io_error() { + if io_err.kind() == ErrorKind::NotFound { + return Ok(None); + } + } + Err(Error::Metadata { + source: e.into(), + path: location.to_string(), + })? + } + } +} + +fn last_modified(metadata: &Metadata) -> DateTime { + metadata + .modified() + .expect("Modified file time should be supported on this platform") + .into() +} + +fn get_etag(metadata: &Metadata) -> String { + let inode = get_inode(metadata); + let size = metadata.len(); + let mtime = metadata + .modified() + .ok() + .and_then(|mtime| mtime.duration_since(SystemTime::UNIX_EPOCH).ok()) + .unwrap_or_default() + .as_micros(); + + // Use an ETag scheme based on that used by many popular HTTP servers + // + // + format!("{inode:x}-{mtime:x}-{size:x}") +} + +fn convert_metadata(metadata: Metadata, location: Path) -> ObjectMeta { + let last_modified = last_modified(&metadata); + + ObjectMeta { + location, + last_modified, + size: metadata.len(), + e_tag: Some(get_etag(&metadata)), + version: None, + } +} + +#[cfg(unix)] +/// We include the inode when available to yield an ETag more resistant to collisions +/// and as used by popular web servers such as [Apache](https://httpd.apache.org/docs/2.2/mod/core.html#fileetag) +fn get_inode(metadata: &Metadata) -> u64 { + std::os::unix::fs::MetadataExt::ino(metadata) +} + +#[cfg(not(unix))] +/// On platforms where an inode isn't available, fallback to just relying on size and mtime +fn get_inode(_metadata: &Metadata) -> u64 { + 0 +} + +/// Convert walkdir results and converts not-found errors into `None`. +/// Convert broken symlinks to `None`. +fn convert_walkdir_result( + res: std::result::Result, +) -> Result> { + match res { + Ok(entry) => { + // To check for broken symlink: call symlink_metadata() - it does not traverse symlinks); + // if ok: check if entry is symlink; and try to read it by calling metadata(). 
+ match symlink_metadata(entry.path()) { + Ok(attr) => { + if attr.is_symlink() { + let target_metadata = metadata(entry.path()); + match target_metadata { + Ok(_) => { + // symlink is valid + Ok(Some(entry)) + } + Err(_) => { + // this is a broken symlink, return None + Ok(None) + } + } + } else { + Ok(Some(entry)) + } + } + Err(_) => Ok(None), + } + } + + Err(walkdir_err) => match walkdir_err.io_error() { + Some(io_err) => match io_err.kind() { + ErrorKind::NotFound => Ok(None), + _ => Err(Error::UnableToWalkDir { + source: walkdir_err, + } + .into()), + }, + None => Err(Error::UnableToWalkDir { + source: walkdir_err, + } + .into()), + }, + } +} + +#[cfg(test)] +mod tests { + use std::fs; + + use futures::TryStreamExt; + use tempfile::TempDir; + + #[cfg(target_family = "unix")] + use tempfile::NamedTempFile; + + use crate::integration::*; + + use super::*; + + #[tokio::test] + #[cfg(target_family = "unix")] + async fn file_test() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + copy_rename_nonexistent_object(&integration).await; + stream_get(&integration).await; + put_opts(&integration, false).await; + } + + #[test] + #[cfg(target_family = "unix")] + fn test_non_tokio() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + futures::executor::block_on(async move { + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + + // Can't use stream_get test as WriteMultipart uses a tokio JoinSet + let p = Path::from("manual_upload"); + let mut upload = integration.put_multipart(&p).await.unwrap(); + upload.put_part("123".into()).await.unwrap(); + upload.put_part("45678".into()).await.unwrap(); + let r = upload.complete().await.unwrap(); + + let get = integration.get(&p).await.unwrap(); + assert_eq!(get.meta.e_tag.as_ref().unwrap(), r.e_tag.as_ref().unwrap()); + let actual = get.bytes().await.unwrap(); + assert_eq!(actual.as_ref(), b"12345678"); + }); + } + + #[tokio::test] + async fn creates_dir_if_not_present() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("nested/file/test_file"); + + let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, data); + } + + #[tokio::test] + async fn unknown_length() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, data); + } + + #[tokio::test] + async fn range_request_start_beyond_end_of_file() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + 
+ let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + integration + .get_range(&location, 100..200) + .await + .expect_err("Should error with start range beyond end of file"); + } + + #[tokio::test] + async fn range_request_beyond_end_of_file() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + let read_data = integration.get_range(&location, 0..100).await.unwrap(); + assert_eq!(&*read_data, data); + } + + #[tokio::test] + #[cfg(target_family = "unix")] + // Fails on github actions runner (which runs the tests as root) + #[ignore] + async fn bubble_up_io_errors() { + use std::{fs::set_permissions, os::unix::prelude::PermissionsExt}; + + let root = TempDir::new().unwrap(); + + // make non-readable + let metadata = root.path().metadata().unwrap(); + let mut permissions = metadata.permissions(); + permissions.set_mode(0o000); + set_permissions(root.path(), permissions).unwrap(); + + let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let mut stream = store.list(None); + let mut any_err = false; + while let Some(res) = stream.next().await { + if res.is_err() { + any_err = true; + } + } + assert!(any_err); + + // `list_with_delimiter + assert!(store.list_with_delimiter(None).await.is_err()); + } + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn get_nonexistent_location() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from(NON_EXISTENT_NAME); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + if let crate::Error::NotFound { path, source } = err { + let source_variant = source.downcast_ref::(); + assert!( + matches!(source_variant, Some(std::io::Error { .. 
}),), + "got: {source_variant:?}" + ); + assert!(path.ends_with(NON_EXISTENT_NAME), "{}", path); + } else { + panic!("unexpected error type: {err:?}"); + } + } + + #[tokio::test] + async fn root() { + let integration = LocalFileSystem::new(); + + let canonical = std::path::Path::new("Cargo.toml").canonicalize().unwrap(); + let url = Url::from_directory_path(&canonical).unwrap(); + let path = Path::parse(url.path()).unwrap(); + + let roundtrip = integration.path_to_filesystem(&path).unwrap(); + + // Needed as on Windows canonicalize returns extended length path syntax + // C:\Users\circleci -> \\?\C:\Users\circleci + let roundtrip = roundtrip.canonicalize().unwrap(); + + assert_eq!(roundtrip, canonical); + + integration.head(&path).await.unwrap(); + } + + #[tokio::test] + #[cfg(target_family = "windows")] + async fn test_list_root() { + let fs = LocalFileSystem::new(); + let r = fs.list_with_delimiter(None).await.unwrap_err().to_string(); + + assert!( + r.contains("Unable to convert URL \"file:///\" to filesystem path"), + "{}", + r + ); + } + + #[tokio::test] + #[cfg(target_os = "linux")] + async fn test_list_root() { + let fs = LocalFileSystem::new(); + fs.list_with_delimiter(None).await.unwrap(); + } + + #[cfg(target_family = "unix")] + async fn check_list(integration: &LocalFileSystem, prefix: Option<&Path>, expected: &[&str]) { + let result: Vec<_> = integration.list(prefix).try_collect().await.unwrap(); + + let mut strings: Vec<_> = result.iter().map(|x| x.location.as_ref()).collect(); + strings.sort_unstable(); + assert_eq!(&strings, expected) + } + + #[tokio::test] + #[cfg(target_family = "unix")] + async fn test_symlink() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let subdir = root.path().join("a"); + std::fs::create_dir(&subdir).unwrap(); + let file = subdir.join("file.parquet"); + std::fs::write(file, "test").unwrap(); + + check_list(&integration, None, &["a/file.parquet"]).await; + integration + .head(&Path::from("a/file.parquet")) + .await + .unwrap(); + + // Follow out of tree symlink + let other = NamedTempFile::new().unwrap(); + std::os::unix::fs::symlink(other.path(), root.path().join("test.parquet")).unwrap(); + + // Should return test.parquet even though out of tree + check_list(&integration, None, &["a/file.parquet", "test.parquet"]).await; + + // Can fetch test.parquet + integration.head(&Path::from("test.parquet")).await.unwrap(); + + // Follow in tree symlink + std::os::unix::fs::symlink(&subdir, root.path().join("b")).unwrap(); + check_list( + &integration, + None, + &["a/file.parquet", "b/file.parquet", "test.parquet"], + ) + .await; + check_list(&integration, Some(&Path::from("b")), &["b/file.parquet"]).await; + + // Can fetch through symlink + integration + .head(&Path::from("b/file.parquet")) + .await + .unwrap(); + + // Ignore broken symlink + std::os::unix::fs::symlink(root.path().join("foo.parquet"), root.path().join("c")).unwrap(); + + check_list( + &integration, + None, + &["a/file.parquet", "b/file.parquet", "test.parquet"], + ) + .await; + + let mut r = integration.list_with_delimiter(None).await.unwrap(); + r.common_prefixes.sort_unstable(); + assert_eq!(r.common_prefixes.len(), 2); + assert_eq!(r.common_prefixes[0].as_ref(), "a"); + assert_eq!(r.common_prefixes[1].as_ref(), "b"); + assert_eq!(r.objects.len(), 1); + assert_eq!(r.objects[0].location.as_ref(), "test.parquet"); + + let r = integration + .list_with_delimiter(Some(&Path::from("a"))) + .await + .unwrap(); + 
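+        // Listing with a delimiter under "a" should surface the file directly
+        // and report no further common prefixes.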
assert_eq!(r.common_prefixes.len(), 0); + assert_eq!(r.objects.len(), 1); + assert_eq!(r.objects[0].location.as_ref(), "a/file.parquet"); + + // Deleting a symlink doesn't delete the source file + integration + .delete(&Path::from("test.parquet")) + .await + .unwrap(); + assert!(other.path().exists()); + + check_list(&integration, None, &["a/file.parquet", "b/file.parquet"]).await; + + // Deleting through a symlink deletes both files + integration + .delete(&Path::from("b/file.parquet")) + .await + .unwrap(); + + check_list(&integration, None, &[]).await; + + // Adding a file through a symlink creates in both paths + integration + .put(&Path::from("b/file.parquet"), vec![0, 1, 2].into()) + .await + .unwrap(); + + check_list(&integration, None, &["a/file.parquet", "b/file.parquet"]).await; + } + + #[tokio::test] + async fn invalid_path() { + let root = TempDir::new().unwrap(); + let root = root.path().join("🙀"); + std::fs::create_dir(root.clone()).unwrap(); + + // Invalid paths supported above root of store + let integration = LocalFileSystem::new_with_prefix(root.clone()).unwrap(); + + let directory = Path::from("directory"); + let object = directory.child("child.txt"); + let data = Bytes::from("arbitrary"); + integration.put(&object, data.clone().into()).await.unwrap(); + integration.head(&object).await.unwrap(); + let result = integration.get(&object).await.unwrap(); + assert_eq!(result.bytes().await.unwrap(), data); + + flatten_list_stream(&integration, None).await.unwrap(); + flatten_list_stream(&integration, Some(&directory)) + .await + .unwrap(); + + let result = integration + .list_with_delimiter(Some(&directory)) + .await + .unwrap(); + assert_eq!(result.objects.len(), 1); + assert!(result.common_prefixes.is_empty()); + assert_eq!(result.objects[0].location, object); + + let emoji = root.join("💀"); + std::fs::write(emoji, "foo").unwrap(); + + // Can list illegal file + let mut paths = flatten_list_stream(&integration, None).await.unwrap(); + paths.sort_unstable(); + + assert_eq!( + paths, + vec![ + Path::parse("directory/child.txt").unwrap(), + Path::parse("💀").unwrap() + ] + ); + } + + #[tokio::test] + async fn list_hides_incomplete_uploads() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let location = Path::from("some_file"); + + let data = PutPayload::from("arbitrary data"); + let mut u1 = integration.put_multipart(&location).await.unwrap(); + u1.put_part(data.clone()).await.unwrap(); + + let mut u2 = integration.put_multipart(&location).await.unwrap(); + u2.put_part(data).await.unwrap(); + + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list.len(), 0); + + assert_eq!( + integration + .list_with_delimiter(None) + .await + .unwrap() + .objects + .len(), + 0 + ); + } + + #[tokio::test] + async fn test_path_with_offset() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let root_path = root.path(); + for i in 0..5 { + let filename = format!("test{}.parquet", i); + let file = root_path.join(filename); + std::fs::write(file, "test").unwrap(); + } + let filter_str = "test"; + let filter = String::from(filter_str); + let offset_str = filter + "1"; + let offset = Path::from(offset_str.clone()); + + // Use list_with_offset to retrieve files + let res = integration.list_with_offset(None, &offset); + let offset_paths: Vec<_> = res.map_ok(|x| x.location).try_collect().await.unwrap(); + let mut offset_files: Vec<_> = 
offset_paths + .iter() + .map(|x| String::from(x.filename().unwrap())) + .collect(); + + // Check result with direct filesystem read + let files = fs::read_dir(root_path).unwrap(); + let filtered_files = files + .filter_map(Result::ok) + .filter_map(|d| { + d.file_name().to_str().and_then(|f| { + if f.contains(filter_str) { + Some(String::from(f)) + } else { + None + } + }) + }) + .collect::>(); + + let mut expected_offset_files: Vec<_> = filtered_files + .iter() + .filter(|s| **s > offset_str) + .cloned() + .collect(); + + fn do_vecs_match(a: &[T], b: &[T]) -> bool { + let matching = a.iter().zip(b.iter()).filter(|&(a, b)| a == b).count(); + matching == a.len() && matching == b.len() + } + + offset_files.sort(); + expected_offset_files.sort(); + + // println!("Expected Offset Files: {:?}", expected_offset_files); + // println!("Actual Offset Files: {:?}", offset_files); + + assert_eq!(offset_files.len(), expected_offset_files.len()); + assert!(do_vecs_match(&expected_offset_files, &offset_files)); + } + + #[tokio::test] + async fn filesystem_filename_with_percent() { + let temp_dir = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(temp_dir.path()).unwrap(); + let filename = "L%3ABC.parquet"; + + std::fs::write(temp_dir.path().join(filename), "foo").unwrap(); + + let res: Vec<_> = integration.list(None).try_collect().await.unwrap(); + assert_eq!(res.len(), 1); + assert_eq!(res[0].location.as_ref(), filename); + + let res = integration.list_with_delimiter(None).await.unwrap(); + assert_eq!(res.objects.len(), 1); + assert_eq!(res.objects[0].location.as_ref(), filename); + } + + #[tokio::test] + async fn relative_paths() { + LocalFileSystem::new_with_prefix(".").unwrap(); + LocalFileSystem::new_with_prefix("..").unwrap(); + LocalFileSystem::new_with_prefix("../..").unwrap(); + + let integration = LocalFileSystem::new(); + let path = Path::from_filesystem_path(".").unwrap(); + integration.list_with_delimiter(Some(&path)).await.unwrap(); + } + + #[test] + fn test_valid_path() { + let cases = [ + ("foo#123/test.txt", true), + ("foo#123/test#23.txt", true), + ("foo#123/test#34", false), + ("foo😁/test#34", false), + ("foo/test#😁34", true), + ]; + + for (case, expected) in cases { + let path = Path::parse(case).unwrap(); + assert_eq!(is_valid_file_path(&path), expected); + } + } + + #[tokio::test] + async fn test_intermediate_files() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let a = Path::parse("foo#123/test.txt").unwrap(); + integration.put(&a, "test".into()).await.unwrap(); + + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list, vec![a.clone()]); + + std::fs::write(root.path().join("bar#123"), "test").unwrap(); + + // Should ignore file + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list, vec![a.clone()]); + + let b = Path::parse("bar#123").unwrap(); + let err = integration.get(&b).await.unwrap_err().to_string(); + assert_eq!(err, "Generic LocalFileSystem error: Filenames containing trailing '/#\\d+/' are not supported: bar#123"); + + let c = Path::parse("foo#123.txt").unwrap(); + integration.put(&c, "test".into()).await.unwrap(); + + let mut list = flatten_list_stream(&integration, None).await.unwrap(); + list.sort_unstable(); + assert_eq!(list, vec![c, a]); + } + + #[tokio::test] + #[cfg(target_os = "windows")] + async fn filesystem_filename_with_colon() { + let root = TempDir::new().unwrap(); + let integration = 
LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let path = Path::parse("file%3Aname.parquet").unwrap(); + let location = Path::parse("file:name.parquet").unwrap(); + + integration.put(&location, "test".into()).await.unwrap(); + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list, vec![path.clone()]); + + let result = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(result, Bytes::from("test")); + } + + #[tokio::test] + async fn delete_dirs_automatically() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()) + .unwrap() + .with_automatic_cleanup(true); + let location = Path::from("nested/file/test_file"); + let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + + assert_eq!(&*read_data, data); + assert!(fs::read_dir(root.path()).unwrap().count() > 0); + integration.delete(&location).await.unwrap(); + assert!(fs::read_dir(root.path()).unwrap().count() == 0); + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[cfg(test)] +mod not_wasm_tests { + use std::time::Duration; + use tempfile::TempDir; + + use crate::local::LocalFileSystem; + use crate::{ObjectStore, Path, PutPayload}; + + #[tokio::test] + async fn test_cleanup_intermediate_files() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + let data = PutPayload::from_static(b"hello"); + let mut upload = integration.put_multipart(&location).await.unwrap(); + upload.put_part(data).await.unwrap(); + + let file_count = std::fs::read_dir(root.path()).unwrap().count(); + assert_eq!(file_count, 1); + drop(upload); + + for _ in 0..100 { + tokio::time::sleep(Duration::from_millis(1)).await; + let file_count = std::fs::read_dir(root.path()).unwrap().count(); + if file_count == 0 { + return; + } + } + panic!("Failed to cleanup file in 100ms") + } +} + +#[cfg(target_family = "unix")] +#[cfg(test)] +mod unix_test { + use std::fs::OpenOptions; + + use nix::sys::stat; + use nix::unistd; + use tempfile::TempDir; + + use crate::local::LocalFileSystem; + use crate::{ObjectStore, Path}; + + #[tokio::test] + async fn test_fifo() { + let filename = "some_file"; + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let path = root.path().join(filename); + unistd::mkfifo(&path, stat::Mode::S_IRWXU).unwrap(); + + // Need to open read and write side in parallel + let spawned = + tokio::task::spawn_blocking(|| OpenOptions::new().write(true).open(path).unwrap()); + + let location = Path::from(filename); + integration.head(&location).await.unwrap(); + integration.get(&location).await.unwrap(); + + spawned.await.unwrap(); + } +} diff --git a/src/memory.rs b/src/memory.rs new file mode 100644 index 0000000..f03dbc6 --- /dev/null +++ b/src/memory.rs @@ -0,0 +1,630 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An in-memory object store implementation +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::Range; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt}; +use parking_lot::RwLock; + +use crate::multipart::{MultipartStore, PartId}; +use crate::util::InvalidGetRange; +use crate::{ + path::Path, Attributes, GetRange, GetResult, GetResultPayload, ListResult, MultipartId, + MultipartUpload, ObjectMeta, ObjectStore, PutMode, PutMultipartOpts, PutOptions, PutResult, + Result, UpdateVersion, UploadPart, +}; +use crate::{GetOptions, PutPayload}; + +/// A specialized `Error` for in-memory object store-related errors +#[derive(Debug, thiserror::Error)] +enum Error { + #[error("No data in memory found. Location: {path}")] + NoDataInMemory { path: String }, + + #[error("Invalid range: {source}")] + Range { source: InvalidGetRange }, + + #[error("Object already exists at that location: {path}")] + AlreadyExists { path: String }, + + #[error("ETag required for conditional update")] + MissingETag, + + #[error("MultipartUpload not found: {id}")] + UploadNotFound { id: String }, + + #[error("Missing part at index: {part}")] + MissingPart { part: usize }, +} + +impl From for super::Error { + fn from(source: Error) -> Self { + match source { + Error::NoDataInMemory { ref path } => Self::NotFound { + path: path.into(), + source: source.into(), + }, + Error::AlreadyExists { ref path } => Self::AlreadyExists { + path: path.into(), + source: source.into(), + }, + _ => Self::Generic { + store: "InMemory", + source: Box::new(source), + }, + } + } +} + +/// In-memory storage suitable for testing or for opting out of using a cloud +/// storage provider. 
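As a sketch (not part of this change), the typical pattern is to construct `InMemory` in tests and exercise code written against the `ObjectStore` trait; the path and payload below are placeholders.

```rust
use object_store::{memory::InMemory, path::Path, ObjectStore};

#[tokio::test]
async fn exercises_code_against_in_memory_store() {
    // Anything accepting `&dyn ObjectStore` / `Box<dyn ObjectStore>` can be
    // tested without touching a real storage backend.
    let store: Box<dyn ObjectStore> = Box::new(InMemory::new());

    let location = Path::from("data/blob.bin");
    store.put(&location, "payload".into()).await.unwrap();

    let bytes = store.get(&location).await.unwrap().bytes().await.unwrap();
    assert_eq!(bytes.as_ref(), b"payload");
}
```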
+#[derive(Debug, Default)] +pub struct InMemory { + storage: SharedStorage, +} + +#[derive(Debug, Clone)] +struct Entry { + data: Bytes, + last_modified: DateTime, + attributes: Attributes, + e_tag: usize, +} + +impl Entry { + fn new( + data: Bytes, + last_modified: DateTime, + e_tag: usize, + attributes: Attributes, + ) -> Self { + Self { + data, + last_modified, + e_tag, + attributes, + } + } +} + +#[derive(Debug, Default, Clone)] +struct Storage { + next_etag: usize, + map: BTreeMap, + uploads: HashMap, +} + +#[derive(Debug, Default, Clone)] +struct PartStorage { + parts: Vec>, +} + +type SharedStorage = Arc>; + +impl Storage { + fn insert(&mut self, location: &Path, bytes: Bytes, attributes: Attributes) -> usize { + let etag = self.next_etag; + self.next_etag += 1; + let entry = Entry::new(bytes, Utc::now(), etag, attributes); + self.overwrite(location, entry); + etag + } + + fn overwrite(&mut self, location: &Path, entry: Entry) { + self.map.insert(location.clone(), entry); + } + + fn create(&mut self, location: &Path, entry: Entry) -> Result<()> { + use std::collections::btree_map; + match self.map.entry(location.clone()) { + btree_map::Entry::Occupied(_) => Err(Error::AlreadyExists { + path: location.to_string(), + } + .into()), + btree_map::Entry::Vacant(v) => { + v.insert(entry); + Ok(()) + } + } + } + + fn update(&mut self, location: &Path, v: UpdateVersion, entry: Entry) -> Result<()> { + match self.map.get_mut(location) { + // Return Precondition instead of NotFound for consistency with stores + None => Err(crate::Error::Precondition { + path: location.to_string(), + source: format!("Object at location {location} not found").into(), + }), + Some(e) => { + let existing = e.e_tag.to_string(); + let expected = v.e_tag.ok_or(Error::MissingETag)?; + if existing == expected { + *e = entry; + Ok(()) + } else { + Err(crate::Error::Precondition { + path: location.to_string(), + source: format!("{existing} does not match {expected}").into(), + }) + } + } + } + } + + fn upload_mut(&mut self, id: &MultipartId) -> Result<&mut PartStorage> { + let parts = id + .parse() + .ok() + .and_then(|x| self.uploads.get_mut(&x)) + .ok_or_else(|| Error::UploadNotFound { id: id.into() })?; + Ok(parts) + } + + fn remove_upload(&mut self, id: &MultipartId) -> Result { + let parts = id + .parse() + .ok() + .and_then(|x| self.uploads.remove(&x)) + .ok_or_else(|| Error::UploadNotFound { id: id.into() })?; + Ok(parts) + } +} + +impl std::fmt::Display for InMemory { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "InMemory") + } +} + +#[async_trait] +impl ObjectStore for InMemory { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let mut storage = self.storage.write(); + let etag = storage.next_etag; + let entry = Entry::new(payload.into(), Utc::now(), etag, opts.attributes); + + match opts.mode { + PutMode::Overwrite => storage.overwrite(location, entry), + PutMode::Create => storage.create(location, entry)?, + PutMode::Update(v) => storage.update(location, v, entry)?, + } + storage.next_etag += 1; + + Ok(PutResult { + e_tag: Some(etag.to_string()), + version: None, + }) + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + Ok(Box::new(InMemoryUpload { + location: location.clone(), + attributes: opts.attributes, + parts: vec![], + storage: Arc::clone(&self.storage), + })) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let 
entry = self.entry(location)?; + let e_tag = entry.e_tag.to_string(); + + let meta = ObjectMeta { + location: location.clone(), + last_modified: entry.last_modified, + size: entry.data.len() as u64, + e_tag: Some(e_tag), + version: None, + }; + options.check_preconditions(&meta)?; + + let (range, data) = match options.range { + Some(range) => { + let r = range + .as_range(entry.data.len() as u64) + .map_err(|source| Error::Range { source })?; + ( + r.clone(), + entry.data.slice(r.start as usize..r.end as usize), + ) + } + None => (0..entry.data.len() as u64, entry.data), + }; + let stream = futures::stream::once(futures::future::ready(Ok(data))); + + Ok(GetResult { + payload: GetResultPayload::Stream(stream.boxed()), + attributes: entry.attributes, + meta, + range, + }) + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + let entry = self.entry(location)?; + ranges + .iter() + .map(|range| { + let r = GetRange::Bounded(range.clone()) + .as_range(entry.data.len() as u64) + .map_err(|source| Error::Range { source })?; + let r_end = usize::try_from(r.end).map_err(|_e| Error::Range { + source: InvalidGetRange::TooLarge { + requested: r.end, + max: usize::MAX as u64, + }, + })?; + let r_start = usize::try_from(r.start).map_err(|_e| Error::Range { + source: InvalidGetRange::TooLarge { + requested: r.start, + max: usize::MAX as u64, + }, + })?; + Ok(entry.data.slice(r_start..r_end)) + }) + .collect() + } + + async fn head(&self, location: &Path) -> Result { + let entry = self.entry(location)?; + + Ok(ObjectMeta { + location: location.clone(), + last_modified: entry.last_modified, + size: entry.data.len() as u64, + e_tag: Some(entry.e_tag.to_string()), + version: None, + }) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.storage.write().map.remove(location); + Ok(()) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let root = Path::default(); + let prefix = prefix.unwrap_or(&root); + + let storage = self.storage.read(); + let values: Vec<_> = storage + .map + .range((prefix)..) + .take_while(|(key, _)| key.as_ref().starts_with(prefix.as_ref())) + .filter(|(key, _)| { + // Don't return for exact prefix match + key.prefix_match(prefix) + .map(|mut x| x.next().is_some()) + .unwrap_or(false) + }) + .map(|(key, value)| { + Ok(ObjectMeta { + location: key.clone(), + last_modified: value.last_modified, + size: value.data.len() as u64, + e_tag: Some(value.e_tag.to_string()), + version: None, + }) + }) + .collect(); + + futures::stream::iter(values).boxed() + } + + /// The memory implementation returns all results, as opposed to the cloud + /// versions which limit their results to 1k or more because of API + /// limitations. + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let root = Path::default(); + let prefix = prefix.unwrap_or(&root); + + let mut common_prefixes = BTreeSet::new(); + + // Only objects in this base level should be returned in the + // response. Otherwise, we just collect the common prefixes. + let mut objects = vec![]; + for (k, v) in self.storage.read().map.range((prefix)..) 
{ + if !k.as_ref().starts_with(prefix.as_ref()) { + break; + } + + let mut parts = match k.prefix_match(prefix) { + Some(parts) => parts, + None => continue, + }; + + // Pop first element + let common_prefix = match parts.next() { + Some(p) => p, + // Should only return children of the prefix + None => continue, + }; + + if parts.next().is_some() { + common_prefixes.insert(prefix.child(common_prefix)); + } else { + let object = ObjectMeta { + location: k.clone(), + last_modified: v.last_modified, + size: v.data.len() as u64, + e_tag: Some(v.e_tag.to_string()), + version: None, + }; + objects.push(object); + } + } + + Ok(ListResult { + objects, + common_prefixes: common_prefixes.into_iter().collect(), + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let entry = self.entry(from)?; + self.storage + .write() + .insert(to, entry.data, entry.attributes); + Ok(()) + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let entry = self.entry(from)?; + let mut storage = self.storage.write(); + if storage.map.contains_key(to) { + return Err(Error::AlreadyExists { + path: to.to_string(), + } + .into()); + } + storage.insert(to, entry.data, entry.attributes); + Ok(()) + } +} + +#[async_trait] +impl MultipartStore for InMemory { + async fn create_multipart(&self, _path: &Path) -> Result { + let mut storage = self.storage.write(); + let etag = storage.next_etag; + storage.next_etag += 1; + storage.uploads.insert(etag, Default::default()); + Ok(etag.to_string()) + } + + async fn put_part( + &self, + _path: &Path, + id: &MultipartId, + part_idx: usize, + payload: PutPayload, + ) -> Result { + let mut storage = self.storage.write(); + let upload = storage.upload_mut(id)?; + if part_idx <= upload.parts.len() { + upload.parts.resize(part_idx + 1, None); + } + upload.parts[part_idx] = Some(payload.into()); + Ok(PartId { + content_id: Default::default(), + }) + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + _parts: Vec, + ) -> Result { + let mut storage = self.storage.write(); + let upload = storage.remove_upload(id)?; + + let mut cap = 0; + for (part, x) in upload.parts.iter().enumerate() { + cap += x.as_ref().ok_or(Error::MissingPart { part })?.len(); + } + let mut buf = Vec::with_capacity(cap); + for x in &upload.parts { + buf.extend_from_slice(x.as_ref().unwrap()) + } + let etag = storage.insert(path, buf.into(), Default::default()); + Ok(PutResult { + e_tag: Some(etag.to_string()), + version: None, + }) + } + + async fn abort_multipart(&self, _path: &Path, id: &MultipartId) -> Result<()> { + self.storage.write().remove_upload(id)?; + Ok(()) + } +} + +impl InMemory { + /// Create new in-memory storage. + pub fn new() -> Self { + Self::default() + } + + /// Creates a fork of the store, with the current content copied into the + /// new store. 
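A short illustrative test (not part of this change) of the copy-on-fork behaviour described above; the object keys are placeholders. Because `fork` clones the current storage into a fresh lock, later writes to either store stay independent.

```rust
use object_store::{memory::InMemory, path::Path, ObjectStore};

#[tokio::test]
async fn fork_is_independent_of_the_original() {
    let base = InMemory::new();
    base.put(&Path::from("seed"), "original".into()).await.unwrap();

    // The fork starts with a copy of the current contents...
    let fork = base.fork();
    fork.get(&Path::from("seed")).await.unwrap();

    // ...but writes made afterwards do not propagate back to `base`.
    fork.put(&Path::from("only-in-fork"), "new".into())
        .await
        .unwrap();
    assert!(base.get(&Path::from("only-in-fork")).await.is_err());
}
```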
+ pub fn fork(&self) -> Self { + let storage = self.storage.read(); + let storage = Arc::new(RwLock::new(storage.clone())); + Self { storage } + } + + fn entry(&self, location: &Path) -> Result { + let storage = self.storage.read(); + let value = storage + .map + .get(location) + .cloned() + .ok_or_else(|| Error::NoDataInMemory { + path: location.to_string(), + })?; + + Ok(value) + } +} + +#[derive(Debug)] +struct InMemoryUpload { + location: Path, + attributes: Attributes, + parts: Vec, + storage: Arc>, +} + +#[async_trait] +impl MultipartUpload for InMemoryUpload { + fn put_part(&mut self, payload: PutPayload) -> UploadPart { + self.parts.push(payload); + Box::pin(futures::future::ready(Ok(()))) + } + + async fn complete(&mut self) -> Result { + let cap = self.parts.iter().map(|x| x.content_length()).sum(); + let mut buf = Vec::with_capacity(cap); + let parts = self.parts.iter().flatten(); + parts.for_each(|x| buf.extend_from_slice(x)); + let etag = self.storage.write().insert( + &self.location, + buf.into(), + std::mem::take(&mut self.attributes), + ); + + Ok(PutResult { + e_tag: Some(etag.to_string()), + version: None, + }) + } + + async fn abort(&mut self) -> Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::integration::*; + + use super::*; + + #[tokio::test] + async fn in_memory_test() { + let integration = InMemory::new(); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + put_opts(&integration, true).await; + multipart(&integration, &integration).await; + put_get_attributes(&integration).await; + } + + #[tokio::test] + async fn box_test() { + let integration: Box = Box::new(InMemory::new()); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn arc_test() { + let integration: Arc = Arc::new(InMemory::new()); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn unknown_length() { + let integration = InMemory::new(); + + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + + integration + .put(&location, data.clone().into()) + .await + .unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, data); + } + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn nonexistent_location() { + let integration = InMemory::new(); + + let location = Path::from(NON_EXISTENT_NAME); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + if let crate::Error::NotFound { path, source } = err { + let source_variant = source.downcast_ref::(); + assert!( + matches!(source_variant, Some(Error::NoDataInMemory { .. 
}),), + "got: {source_variant:?}" + ); + assert_eq!(path, NON_EXISTENT_NAME); + } else { + panic!("unexpected error type: {err:?}"); + } + } +} diff --git a/src/multipart.rs b/src/multipart.rs new file mode 100644 index 0000000..d94e7f1 --- /dev/null +++ b/src/multipart.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cloud Multipart Upload +//! +//! This crate provides an asynchronous interface for multipart file uploads to +//! cloud storage services. It's designed to offer efficient, non-blocking operations, +//! especially useful when dealing with large files or high-throughput systems. + +use async_trait::async_trait; + +use crate::path::Path; +use crate::{MultipartId, PutPayload, PutResult, Result}; + +/// Represents a part of a file that has been successfully uploaded in a multipart upload process. +#[derive(Debug, Clone)] +pub struct PartId { + /// Id of this part + pub content_id: String, +} + +/// A low-level interface for interacting with multipart upload APIs +/// +/// Most use-cases should prefer [`ObjectStore::put_multipart`] as this is supported by more +/// backends, including [`LocalFileSystem`], and automatically handles uploading fixed +/// size parts of sufficient size in parallel +/// +/// [`ObjectStore::put_multipart`]: crate::ObjectStore::put_multipart +/// [`LocalFileSystem`]: crate::local::LocalFileSystem +#[async_trait] +pub trait MultipartStore: Send + Sync + 'static { + /// Creates a new multipart upload, returning the [`MultipartId`] + async fn create_multipart(&self, path: &Path) -> Result; + + /// Uploads a new part with index `part_idx` + /// + /// `part_idx` should be an integer in the range `0..N` where `N` is the number of + /// parts in the upload. Parts may be uploaded concurrently and in any order. + /// + /// Most stores require that all parts excluding the last are at least 5 MiB, and some + /// further require that all parts excluding the last be the same size, e.g. [R2]. + /// [`WriteMultipart`] performs writes in fixed size blocks of 5 MiB, and clients wanting + /// to maximise compatibility should look to do likewise. + /// + /// [R2]: https://developers.cloudflare.com/r2/objects/multipart-objects/#limitations + /// [`WriteMultipart`]: crate::upload::WriteMultipart + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + data: PutPayload, + ) -> Result; + + /// Completes a multipart upload + /// + /// The `i`'th value of `parts` must be a [`PartId`] returned by a call to [`Self::put_part`] + /// with a `part_idx` of `i`, and the same `path` and `id` as provided to this method. 
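As a rough usage sketch of this trait (assuming a Tokio runtime and the crate's `InMemory` store from earlier in this diff, which implements it), the low-level flow is: create the upload, push parts by index, then complete with the returned `PartId`s in order.

```rust
// A minimal sketch of the low-level MultipartStore workflow, assuming a Tokio
// runtime and the crate's InMemory store (which implements the trait, as shown
// earlier in this diff). InMemory does not enforce the 5 MiB minimum part size
// that real cloud stores require.
use object_store::memory::InMemory;
use object_store::multipart::MultipartStore;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let path = Path::from("data/file.bin");

    // Start the upload, push parts by index, then complete in index order.
    let id = store.create_multipart(&path).await?;
    let p0 = store
        .put_part(&path, &id, 0, PutPayload::from_static(b"hello "))
        .await?;
    let p1 = store
        .put_part(&path, &id, 1, PutPayload::from_static(b"world"))
        .await?;
    store.complete_multipart(&path, &id, vec![p0, p1]).await?;

    let bytes = store.get(&path).await?.bytes().await?;
    assert_eq!(bytes.as_ref(), b"hello world");
    Ok(())
}
```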
Calling + /// this method with out of sequence or repeated [`PartId`], or [`PartId`] returned for other + /// values of `path` or `id`, will result in implementation-defined behaviour + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result; + + /// Aborts a multipart upload + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()>; +} diff --git a/src/parse.rs b/src/parse.rs new file mode 100644 index 0000000..00ea6cf --- /dev/null +++ b/src/parse.rs @@ -0,0 +1,373 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] +use crate::local::LocalFileSystem; +use crate::memory::InMemory; +use crate::path::Path; +use crate::ObjectStore; +use url::Url; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Unable to recognise URL \"{}\"", url)] + Unrecognised { url: Url }, + + #[error(transparent)] + Path { + #[from] + source: crate::path::Error, + }, +} + +impl From for super::Error { + fn from(e: Error) -> Self { + Self::Generic { + store: "URL", + source: Box::new(e), + } + } +} + +/// Recognizes various URL formats, identifying the relevant [`ObjectStore`] +/// +/// See [`ObjectStoreScheme::parse`] for more details +/// +/// # Supported formats: +/// - `file:///path/to/my/file` -> [`LocalFileSystem`] +/// - `memory:///` -> [`InMemory`] +/// - `s3://bucket/path` -> [`AmazonS3`](crate::aws::AmazonS3) (also supports `s3a`) +/// - `gs://bucket/path` -> [`GoogleCloudStorage`](crate::gcp::GoogleCloudStorage) +/// - `az://account/container/path` -> [`MicrosoftAzure`](crate::azure::MicrosoftAzure) (also supports `adl`, `azure`, `abfs`, `abfss`) +/// - `http://mydomain/path` -> [`HttpStore`](crate::http::HttpStore) +/// - `https://mydomain/path` -> [`HttpStore`](crate::http::HttpStore) +/// +/// There are also special cases for AWS and Azure for `https://{host?}/path` paths: +/// - `dfs.core.windows.net`, `blob.core.windows.net`, `dfs.fabric.microsoft.com`, `blob.fabric.microsoft.com` -> [`MicrosoftAzure`](crate::azure::MicrosoftAzure) +/// - `amazonaws.com` -> [`AmazonS3`](crate::aws::AmazonS3) +/// - `r2.cloudflarestorage.com` -> [`AmazonS3`](crate::aws::AmazonS3) +/// +#[non_exhaustive] // permit new variants +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum ObjectStoreScheme { + /// Url corresponding to [`LocalFileSystem`] + Local, + /// Url corresponding to [`InMemory`] + Memory, + /// Url corresponding to [`AmazonS3`](crate::aws::AmazonS3) + AmazonS3, + /// Url corresponding to [`GoogleCloudStorage`](crate::gcp::GoogleCloudStorage) + GoogleCloudStorage, + /// Url corresponding to [`MicrosoftAzure`](crate::azure::MicrosoftAzure) + MicrosoftAzure, + /// Url corresponding to [`HttpStore`](crate::http::HttpStore) + 
Http, +} + +impl ObjectStoreScheme { + /// Create an [`ObjectStoreScheme`] from the provided [`Url`] + /// + /// Returns the [`ObjectStoreScheme`] and the remaining [`Path`] + /// + /// # Example + /// ``` + /// # use url::Url; + /// # use object_store::ObjectStoreScheme; + /// let url: Url = "file:///path/to/my/file".parse().unwrap(); + /// let (scheme, path) = ObjectStoreScheme::parse(&url).unwrap(); + /// assert_eq!(scheme, ObjectStoreScheme::Local); + /// assert_eq!(path.as_ref(), "path/to/my/file"); + /// + /// let url: Url = "https://blob.core.windows.net/path/to/my/file".parse().unwrap(); + /// let (scheme, path) = ObjectStoreScheme::parse(&url).unwrap(); + /// assert_eq!(scheme, ObjectStoreScheme::MicrosoftAzure); + /// assert_eq!(path.as_ref(), "path/to/my/file"); + /// + /// let url: Url = "https://example.com/path/to/my/file".parse().unwrap(); + /// let (scheme, path) = ObjectStoreScheme::parse(&url).unwrap(); + /// assert_eq!(scheme, ObjectStoreScheme::Http); + /// assert_eq!(path.as_ref(), "path/to/my/file"); + /// ``` + pub fn parse(url: &Url) -> Result<(Self, Path), Error> { + let strip_bucket = || Some(url.path().strip_prefix('/')?.split_once('/')?.1); + + let (scheme, path) = match (url.scheme(), url.host_str()) { + ("file", None) => (Self::Local, url.path()), + ("memory", None) => (Self::Memory, url.path()), + ("s3" | "s3a", Some(_)) => (Self::AmazonS3, url.path()), + ("gs", Some(_)) => (Self::GoogleCloudStorage, url.path()), + ("az" | "adl" | "azure" | "abfs" | "abfss", Some(_)) => { + (Self::MicrosoftAzure, url.path()) + } + ("http", Some(_)) => (Self::Http, url.path()), + ("https", Some(host)) => { + if host.ends_with("dfs.core.windows.net") + || host.ends_with("blob.core.windows.net") + || host.ends_with("dfs.fabric.microsoft.com") + || host.ends_with("blob.fabric.microsoft.com") + { + (Self::MicrosoftAzure, url.path()) + } else if host.ends_with("amazonaws.com") { + match host.starts_with("s3") { + true => (Self::AmazonS3, strip_bucket().unwrap_or_default()), + false => (Self::AmazonS3, url.path()), + } + } else if host.ends_with("r2.cloudflarestorage.com") { + (Self::AmazonS3, strip_bucket().unwrap_or_default()) + } else { + (Self::Http, url.path()) + } + } + _ => return Err(Error::Unrecognised { url: url.clone() }), + }; + + Ok((scheme, Path::from_url_path(path)?)) + } +} + +#[cfg(feature = "cloud")] +macro_rules! builder_opts { + ($builder:ty, $url:expr, $options:expr) => {{ + let builder = $options.into_iter().fold( + <$builder>::new().with_url($url.to_string()), + |builder, (key, value)| match key.as_ref().parse() { + Ok(k) => builder.with_config(k, value), + Err(_) => builder, + }, + ); + Box::new(builder.build()?) 
as _ + }}; +} + +/// Create an [`ObjectStore`] based on the provided `url` +/// +/// Returns +/// - An [`ObjectStore`] of the corresponding type +/// - The [`Path`] into the [`ObjectStore`] of the addressed resource +pub fn parse_url(url: &Url) -> Result<(Box, Path), super::Error> { + parse_url_opts(url, std::iter::empty::<(&str, &str)>()) +} + +/// Create an [`ObjectStore`] based on the provided `url` and options +/// +/// Returns +/// - An [`ObjectStore`] of the corresponding type +/// - The [`Path`] into the [`ObjectStore`] of the addressed resource +pub fn parse_url_opts( + url: &Url, + options: I, +) -> Result<(Box, Path), super::Error> +where + I: IntoIterator, + K: AsRef, + V: Into, +{ + let _options = options; + let (scheme, path) = ObjectStoreScheme::parse(url)?; + let path = Path::parse(path)?; + + let store = match scheme { + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + ObjectStoreScheme::Local => Box::new(LocalFileSystem::new()) as _, + ObjectStoreScheme::Memory => Box::new(InMemory::new()) as _, + #[cfg(feature = "aws")] + ObjectStoreScheme::AmazonS3 => { + builder_opts!(crate::aws::AmazonS3Builder, url, _options) + } + #[cfg(feature = "gcp")] + ObjectStoreScheme::GoogleCloudStorage => { + builder_opts!(crate::gcp::GoogleCloudStorageBuilder, url, _options) + } + #[cfg(feature = "azure")] + ObjectStoreScheme::MicrosoftAzure => { + builder_opts!(crate::azure::MicrosoftAzureBuilder, url, _options) + } + #[cfg(feature = "http")] + ObjectStoreScheme::Http => { + let url = &url[..url::Position::BeforePath]; + builder_opts!(crate::http::HttpBuilder, url, _options) + } + #[cfg(not(all( + feature = "aws", + feature = "azure", + feature = "gcp", + feature = "http", + not(target_arch = "wasm32") + )))] + s => { + return Err(super::Error::Generic { + store: "parse_url", + source: format!("feature for {s:?} not enabled").into(), + }) + } + }; + + Ok((store, path)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse() { + let cases = [ + ("file:/path", (ObjectStoreScheme::Local, "path")), + ("file:///path", (ObjectStoreScheme::Local, "path")), + ("memory:/path", (ObjectStoreScheme::Memory, "path")), + ("memory:///", (ObjectStoreScheme::Memory, "")), + ("s3://bucket/path", (ObjectStoreScheme::AmazonS3, "path")), + ("s3a://bucket/path", (ObjectStoreScheme::AmazonS3, "path")), + ( + "https://s3.region.amazonaws.com/bucket", + (ObjectStoreScheme::AmazonS3, ""), + ), + ( + "https://s3.region.amazonaws.com/bucket/path", + (ObjectStoreScheme::AmazonS3, "path"), + ), + ( + "https://bucket.s3.region.amazonaws.com", + (ObjectStoreScheme::AmazonS3, ""), + ), + ( + "https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket", + (ObjectStoreScheme::AmazonS3, ""), + ), + ( + "https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket/path", + (ObjectStoreScheme::AmazonS3, "path"), + ), + ( + "abfs://container/path", + (ObjectStoreScheme::MicrosoftAzure, "path"), + ), + ( + "abfs://file_system@account_name.dfs.core.windows.net/path", + (ObjectStoreScheme::MicrosoftAzure, "path"), + ), + ( + "abfss://file_system@account_name.dfs.core.windows.net/path", + (ObjectStoreScheme::MicrosoftAzure, "path"), + ), + ( + "https://account.dfs.core.windows.net", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "https://account.blob.core.windows.net", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "gs://bucket/path", + (ObjectStoreScheme::GoogleCloudStorage, "path"), + ), + ( + "gs://test.example.com/path", + (ObjectStoreScheme::GoogleCloudStorage, "path"), + ), + 
("http://mydomain/path", (ObjectStoreScheme::Http, "path")), + ("https://mydomain/path", (ObjectStoreScheme::Http, "path")), + ( + "s3://bucket/foo%20bar", + (ObjectStoreScheme::AmazonS3, "foo bar"), + ), + ( + "https://foo/bar%20baz", + (ObjectStoreScheme::Http, "bar baz"), + ), + ( + "file:///bar%252Efoo", + (ObjectStoreScheme::Local, "bar%2Efoo"), + ), + ( + "abfss://file_system@account.dfs.fabric.microsoft.com/", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "abfss://file_system@account.dfs.fabric.microsoft.com/", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "https://account.dfs.fabric.microsoft.com/", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "https://account.dfs.fabric.microsoft.com/container", + (ObjectStoreScheme::MicrosoftAzure, "container"), + ), + ( + "https://account.blob.fabric.microsoft.com/", + (ObjectStoreScheme::MicrosoftAzure, ""), + ), + ( + "https://account.blob.fabric.microsoft.com/container", + (ObjectStoreScheme::MicrosoftAzure, "container"), + ), + ]; + + for (s, (expected_scheme, expected_path)) in cases { + let url = Url::parse(s).unwrap(); + let (scheme, path) = ObjectStoreScheme::parse(&url).unwrap(); + + assert_eq!(scheme, expected_scheme, "{s}"); + assert_eq!(path, Path::parse(expected_path).unwrap(), "{s}"); + } + + let neg_cases = [ + "unix:/run/foo.socket", + "file://remote/path", + "memory://remote/", + ]; + for s in neg_cases { + let url = Url::parse(s).unwrap(); + assert!(ObjectStoreScheme::parse(&url).is_err()); + } + } + + #[test] + fn test_url_spaces() { + let url = Url::parse("file:///my file with spaces").unwrap(); + assert_eq!(url.path(), "/my%20file%20with%20spaces"); + let (_, path) = parse_url(&url).unwrap(); + assert_eq!(path.as_ref(), "my file with spaces"); + } + + #[tokio::test] + #[cfg(feature = "http")] + async fn test_url_http() { + use crate::client::mock_server::MockServer; + use http::{header::USER_AGENT, Response}; + + let server = MockServer::new().await; + + server.push_fn(|r| { + assert_eq!(r.uri().path(), "/foo/bar"); + assert_eq!(r.headers().get(USER_AGENT).unwrap(), "test_url"); + Response::new(String::new()) + }); + + let test = format!("{}/foo/bar", server.url()); + let opts = [("user_agent", "test_url"), ("allow_http", "true")]; + let url = test.parse().unwrap(); + let (store, path) = parse_url_opts(&url, opts).unwrap(); + assert_eq!(path.as_ref(), "foo/bar"); + store.get(&path).await.unwrap(); + + server.shutdown().await; + } +} diff --git a/src/path/mod.rs b/src/path/mod.rs new file mode 100644 index 0000000..f8affe8 --- /dev/null +++ b/src/path/mod.rs @@ -0,0 +1,614 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Path abstraction for Object Storage + +use itertools::Itertools; +use percent_encoding::percent_decode; +use std::fmt::Formatter; +#[cfg(not(target_arch = "wasm32"))] +use url::Url; + +/// The delimiter to separate object namespaces, creating a directory structure. +pub const DELIMITER: &str = "/"; + +/// The path delimiter as a single byte +pub const DELIMITER_BYTE: u8 = DELIMITER.as_bytes()[0]; + +mod parts; + +pub use parts::{InvalidPart, PathPart}; + +/// Error returned by [`Path::parse`] +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error { + /// Error when there's an empty segment between two slashes `/` in the path + #[error("Path \"{}\" contained empty path segment", path)] + EmptySegment { + /// The source path + path: String, + }, + + /// Error when an invalid segment is encountered in the given path + #[error("Error parsing Path \"{}\": {}", path, source)] + BadSegment { + /// The source path + path: String, + /// The part containing the error + source: InvalidPart, + }, + + /// Error when path cannot be canonicalized + #[error("Failed to canonicalize path \"{}\": {}", path.display(), source)] + Canonicalize { + /// The source path + path: std::path::PathBuf, + /// The underlying error + source: std::io::Error, + }, + + /// Error when the path is not a valid URL + #[error("Unable to convert path \"{}\" to URL", path.display())] + InvalidPath { + /// The source path + path: std::path::PathBuf, + }, + + /// Error when a path contains non-unicode characters + #[error("Path \"{}\" contained non-unicode characters: {}", path, source)] + NonUnicode { + /// The source path + path: String, + /// The underlying `UTF8Error` + source: std::str::Utf8Error, + }, + + /// Error when the a path doesn't start with given prefix + #[error("Path {} does not start with prefix {}", path, prefix)] + PrefixMismatch { + /// The source path + path: String, + /// The mismatched prefix + prefix: String, + }, +} + +/// A parsed path representation that can be safely written to object storage +/// +/// A [`Path`] maintains the following invariants: +/// +/// * Paths are delimited by `/` +/// * Paths do not contain leading or trailing `/` +/// * Paths do not contain relative path segments, i.e. `.` or `..` +/// * Paths do not contain empty path segments +/// * Paths do not contain any ASCII control characters +/// +/// There are no enforced restrictions on path length, however, it should be noted that most +/// object stores do not permit paths longer than 1024 bytes, and many filesystems do not +/// support path segments longer than 255 bytes. +/// +/// # Encode +/// +/// In theory object stores support any UTF-8 character sequence, however, certain character +/// sequences cause compatibility problems with some applications and protocols. Additionally +/// some filesystems may impose character restrictions, see [`LocalFileSystem`]. As such the +/// naming guidelines for [S3], [GCS] and [Azure Blob Storage] all recommend sticking to a +/// limited character subset. +/// +/// [S3]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html +/// [GCS]: https://cloud.google.com/storage/docs/naming-objects +/// [Azure Blob Storage]: https://docs.microsoft.com/en-us/rest/api/storageservices/Naming-and-Referencing-Containers--Blobs--and-Metadata#blob-names +/// +/// A string containing potentially problematic path segments can therefore be encoded to a [`Path`] +/// using [`Path::from`] or [`Path::from_iter`]. 
This will percent encode any problematic +/// segments according to [RFC 1738]. +/// +/// ``` +/// # use object_store::path::Path; +/// assert_eq!(Path::from("foo/bar").as_ref(), "foo/bar"); +/// assert_eq!(Path::from("foo//bar").as_ref(), "foo/bar"); +/// assert_eq!(Path::from("foo/../bar").as_ref(), "foo/%2E%2E/bar"); +/// assert_eq!(Path::from("/").as_ref(), ""); +/// assert_eq!(Path::from_iter(["foo", "foo/bar"]).as_ref(), "foo/foo%2Fbar"); +/// ``` +/// +/// Note: if provided with an already percent encoded string, this will encode it again +/// +/// ``` +/// # use object_store::path::Path; +/// assert_eq!(Path::from("foo/foo%2Fbar").as_ref(), "foo/foo%252Fbar"); +/// ``` +/// +/// # Parse +/// +/// Alternatively a [`Path`] can be parsed from an existing string, returning an +/// error if it is invalid. Unlike the encoding methods above, this will permit +/// arbitrary unicode, including percent encoded sequences. +/// +/// ``` +/// # use object_store::path::Path; +/// assert_eq!(Path::parse("/foo/foo%2Fbar").unwrap().as_ref(), "foo/foo%2Fbar"); +/// Path::parse("..").unwrap_err(); // Relative path segments are disallowed +/// Path::parse("/foo//").unwrap_err(); // Empty path segments are disallowed +/// Path::parse("\x00").unwrap_err(); // ASCII control characters are disallowed +/// ``` +/// +/// [RFC 1738]: https://www.ietf.org/rfc/rfc1738.txt +/// [`LocalFileSystem`]: crate::local::LocalFileSystem +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct Path { + /// The raw path with no leading or trailing delimiters + raw: String, +} + +impl Path { + /// Parse a string as a [`Path`], returning a [`Error`] if invalid, + /// as defined on the docstring for [`Path`] + /// + /// Note: this will strip any leading `/` or trailing `/` + pub fn parse(path: impl AsRef) -> Result { + let path = path.as_ref(); + + let stripped = path.strip_prefix(DELIMITER).unwrap_or(path); + if stripped.is_empty() { + return Ok(Default::default()); + } + + let stripped = stripped.strip_suffix(DELIMITER).unwrap_or(stripped); + + for segment in stripped.split(DELIMITER) { + if segment.is_empty() { + return Err(Error::EmptySegment { path: path.into() }); + } + + PathPart::parse(segment).map_err(|source| { + let path = path.into(); + Error::BadSegment { source, path } + })?; + } + + Ok(Self { + raw: stripped.to_string(), + }) + } + + #[cfg(not(target_arch = "wasm32"))] + /// Convert a filesystem path to a [`Path`] relative to the filesystem root + /// + /// This will return an error if the path contains illegal character sequences + /// as defined on the docstring for [`Path`] or does not exist + /// + /// Note: this will canonicalize the provided path, resolving any symlinks + pub fn from_filesystem_path(path: impl AsRef) -> Result { + let absolute = std::fs::canonicalize(&path).map_err(|source| { + let path = path.as_ref().into(); + Error::Canonicalize { source, path } + })?; + + Self::from_absolute_path(absolute) + } + + #[cfg(not(target_arch = "wasm32"))] + /// Convert an absolute filesystem path to a [`Path`] relative to the filesystem root + /// + /// This will return an error if the path contains illegal character sequences, + /// as defined on the docstring for [`Path`], or `base` is not an absolute path + pub fn from_absolute_path(path: impl AsRef) -> Result { + Self::from_absolute_path_with_base(path, None) + } + + #[cfg(not(target_arch = "wasm32"))] + /// Convert a filesystem path to a [`Path`] relative to the provided base + /// + /// This will return an error if the path 
contains illegal character sequences, + /// as defined on the docstring for [`Path`], or `base` does not refer to a parent + /// path of `path`, or `base` is not an absolute path + pub(crate) fn from_absolute_path_with_base( + path: impl AsRef, + base: Option<&Url>, + ) -> Result { + let url = absolute_path_to_url(path)?; + let path = match base { + Some(prefix) => { + url.path() + .strip_prefix(prefix.path()) + .ok_or_else(|| Error::PrefixMismatch { + path: url.path().to_string(), + prefix: prefix.to_string(), + })? + } + None => url.path(), + }; + + // Reverse any percent encoding performed by conversion to URL + Self::from_url_path(path) + } + + /// Parse a url encoded string as a [`Path`], returning a [`Error`] if invalid + /// + /// This will return an error if the path contains illegal character sequences + /// as defined on the docstring for [`Path`] + pub fn from_url_path(path: impl AsRef) -> Result { + let path = path.as_ref(); + let decoded = percent_decode(path.as_bytes()) + .decode_utf8() + .map_err(|source| { + let path = path.into(); + Error::NonUnicode { source, path } + })?; + + Self::parse(decoded) + } + + /// Returns the [`PathPart`] of this [`Path`] + pub fn parts(&self) -> impl Iterator> { + self.raw + .split_terminator(DELIMITER) + .map(|s| PathPart { raw: s.into() }) + } + + /// Returns the last path segment containing the filename stored in this [`Path`] + pub fn filename(&self) -> Option<&str> { + match self.raw.is_empty() { + true => None, + false => self.raw.rsplit(DELIMITER).next(), + } + } + + /// Returns the extension of the file stored in this [`Path`], if any + pub fn extension(&self) -> Option<&str> { + self.filename() + .and_then(|f| f.rsplit_once('.')) + .and_then(|(_, extension)| { + if extension.is_empty() { + None + } else { + Some(extension) + } + }) + } + + /// Returns an iterator of the [`PathPart`] of this [`Path`] after `prefix` + /// + /// Returns `None` if the prefix does not match + pub fn prefix_match(&self, prefix: &Self) -> Option> + '_> { + let mut stripped = self.raw.strip_prefix(&prefix.raw)?; + if !stripped.is_empty() && !prefix.raw.is_empty() { + stripped = stripped.strip_prefix(DELIMITER)?; + } + let iter = stripped + .split_terminator(DELIMITER) + .map(|x| PathPart { raw: x.into() }); + Some(iter) + } + + /// Returns true if this [`Path`] starts with `prefix` + pub fn prefix_matches(&self, prefix: &Self) -> bool { + self.prefix_match(prefix).is_some() + } + + /// Creates a new child of this [`Path`] + pub fn child<'a>(&self, child: impl Into>) -> Self { + let raw = match self.raw.is_empty() { + true => format!("{}", child.into().raw), + false => format!("{}{}{}", self.raw, DELIMITER, child.into().raw), + }; + + Self { raw } + } +} + +impl AsRef for Path { + fn as_ref(&self) -> &str { + &self.raw + } +} + +impl From<&str> for Path { + fn from(path: &str) -> Self { + Self::from_iter(path.split(DELIMITER)) + } +} + +impl From for Path { + fn from(path: String) -> Self { + Self::from_iter(path.split(DELIMITER)) + } +} + +impl From for String { + fn from(path: Path) -> Self { + path.raw + } +} + +impl std::fmt::Display for Path { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.raw.fmt(f) + } +} + +impl<'a, I> FromIterator for Path +where + I: Into>, +{ + fn from_iter>(iter: T) -> Self { + let raw = T::into_iter(iter) + .map(|s| s.into()) + .filter(|s| !s.raw.is_empty()) + .map(|s| s.raw) + .join(DELIMITER); + + Self { raw } + } +} + +#[cfg(not(target_arch = "wasm32"))] +/// Given an absolute filesystem path convert it to 
a URL representation without canonicalization +pub(crate) fn absolute_path_to_url(path: impl AsRef) -> Result { + Url::from_file_path(&path).map_err(|_| Error::InvalidPath { + path: path.as_ref().into(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cloud_prefix_with_trailing_delimiter() { + // Use case: files exist in object storage named `foo/bar.json` and + // `foo_test.json`. A search for the prefix `foo/` should return + // `foo/bar.json` but not `foo_test.json'. + let prefix = Path::from_iter(["test"]); + assert_eq!(prefix.as_ref(), "test"); + } + + #[test] + fn push_encodes() { + let location = Path::from_iter(["foo/bar", "baz%2Ftest"]); + assert_eq!(location.as_ref(), "foo%2Fbar/baz%252Ftest"); + } + + #[test] + fn test_parse() { + assert_eq!(Path::parse("/").unwrap().as_ref(), ""); + assert_eq!(Path::parse("").unwrap().as_ref(), ""); + + let err = Path::parse("//").unwrap_err(); + assert!(matches!(err, Error::EmptySegment { .. })); + + assert_eq!(Path::parse("/foo/bar/").unwrap().as_ref(), "foo/bar"); + assert_eq!(Path::parse("foo/bar/").unwrap().as_ref(), "foo/bar"); + assert_eq!(Path::parse("foo/bar").unwrap().as_ref(), "foo/bar"); + + let err = Path::parse("foo///bar").unwrap_err(); + assert!(matches!(err, Error::EmptySegment { .. })); + } + + #[test] + fn convert_raw_before_partial_eq() { + // dir and file_name + let cloud = Path::from("test_dir/test_file.json"); + let built = Path::from_iter(["test_dir", "test_file.json"]); + + assert_eq!(built, cloud); + + // dir and file_name w/o dot + let cloud = Path::from("test_dir/test_file"); + let built = Path::from_iter(["test_dir", "test_file"]); + + assert_eq!(built, cloud); + + // dir, no file + let cloud = Path::from("test_dir/"); + let built = Path::from_iter(["test_dir"]); + assert_eq!(built, cloud); + + // file_name, no dir + let cloud = Path::from("test_file.json"); + let built = Path::from_iter(["test_file.json"]); + assert_eq!(built, cloud); + + // empty + let cloud = Path::from(""); + let built = Path::from_iter(["", ""]); + + assert_eq!(built, cloud); + } + + #[test] + fn parts_after_prefix_behavior() { + let existing_path = Path::from("apple/bear/cow/dog/egg.json"); + + // Prefix with one directory + let prefix = Path::from("apple"); + let expected_parts: Vec> = vec!["bear", "cow", "dog", "egg.json"] + .into_iter() + .map(Into::into) + .collect(); + let parts: Vec<_> = existing_path.prefix_match(&prefix).unwrap().collect(); + assert_eq!(parts, expected_parts); + + // Prefix with two directories + let prefix = Path::from("apple/bear"); + let expected_parts: Vec> = vec!["cow", "dog", "egg.json"] + .into_iter() + .map(Into::into) + .collect(); + let parts: Vec<_> = existing_path.prefix_match(&prefix).unwrap().collect(); + assert_eq!(parts, expected_parts); + + // Not a prefix + let prefix = Path::from("cow"); + assert!(existing_path.prefix_match(&prefix).is_none()); + + // Prefix with a partial directory + let prefix = Path::from("ap"); + assert!(existing_path.prefix_match(&prefix).is_none()); + + // Prefix matches but there aren't any parts after it + let existing = Path::from("apple/bear/cow/dog"); + + assert_eq!(existing.prefix_match(&existing).unwrap().count(), 0); + assert_eq!(Path::default().parts().count(), 0); + } + + #[test] + fn prefix_matches() { + let haystack = Path::from_iter(["foo/bar", "baz%2Ftest", "something"]); + // self starts with self + assert!( + haystack.prefix_matches(&haystack), + "{haystack:?} should have started with {haystack:?}" + ); + + // a longer prefix doesn't match 
+ let needle = haystack.child("longer now"); + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} shouldn't have started with {needle:?}" + ); + + // one dir prefix matches + let needle = Path::from_iter(["foo/bar"]); + assert!( + haystack.prefix_matches(&needle), + "{haystack:?} should have started with {needle:?}" + ); + + // two dir prefix matches + let needle = needle.child("baz%2Ftest"); + assert!( + haystack.prefix_matches(&needle), + "{haystack:?} should have started with {needle:?}" + ); + + // partial dir prefix doesn't match + let needle = Path::from_iter(["f"]); + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + + // one dir and one partial dir doesn't match + let needle = Path::from_iter(["foo/bar", "baz"]); + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + + // empty prefix matches + let needle = Path::from(""); + assert!( + haystack.prefix_matches(&needle), + "{haystack:?} should have started with {needle:?}" + ); + } + + #[test] + fn prefix_matches_with_file_name() { + let haystack = Path::from_iter(["foo/bar", "baz%2Ftest", "something", "foo.segment"]); + + // All directories match and file name is a prefix + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "something", "foo"]); + + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + + // All directories match but file name is not a prefix + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "something", "e"]); + + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + + // Not all directories match; file name is a prefix of the next directory; this + // does not match + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "s"]); + + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + + // Not all directories match; file name is NOT a prefix of the next directory; + // no match + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "p"]); + + assert!( + !haystack.prefix_matches(&needle), + "{haystack:?} should not have started with {needle:?}" + ); + } + + #[test] + fn path_containing_spaces() { + let a = Path::from_iter(["foo bar", "baz"]); + let b = Path::from("foo bar/baz"); + let c = Path::parse("foo bar/baz").unwrap(); + + assert_eq!(a.raw, "foo bar/baz"); + assert_eq!(a.raw, b.raw); + assert_eq!(b.raw, c.raw); + } + + #[test] + fn from_url_path() { + let a = Path::from_url_path("foo%20bar").unwrap(); + let b = Path::from_url_path("foo/%2E%2E/bar").unwrap_err(); + let c = Path::from_url_path("foo%2F%252E%252E%2Fbar").unwrap(); + let d = Path::from_url_path("foo/%252E%252E/bar").unwrap(); + let e = Path::from_url_path("%48%45%4C%4C%4F").unwrap(); + let f = Path::from_url_path("foo/%FF/as").unwrap_err(); + + assert_eq!(a.raw, "foo bar"); + assert!(matches!(b, Error::BadSegment { .. })); + assert_eq!(c.raw, "foo/%2E%2E/bar"); + assert_eq!(d.raw, "foo/%2E%2E/bar"); + assert_eq!(e.raw, "HELLO"); + assert!(matches!(f, Error::NonUnicode { .. 
})); + } + + #[test] + fn filename_from_path() { + let a = Path::from("foo/bar"); + let b = Path::from("foo/bar.baz"); + let c = Path::from("foo.bar/baz"); + + assert_eq!(a.filename(), Some("bar")); + assert_eq!(b.filename(), Some("bar.baz")); + assert_eq!(c.filename(), Some("baz")); + } + + #[test] + fn file_extension() { + let a = Path::from("foo/bar"); + let b = Path::from("foo/bar.baz"); + let c = Path::from("foo.bar/baz"); + let d = Path::from("foo.bar/baz.qux"); + + assert_eq!(a.extension(), None); + assert_eq!(b.extension(), Some("baz")); + assert_eq!(c.extension(), None); + assert_eq!(d.extension(), Some("qux")); + } +} diff --git a/src/path/parts.rs b/src/path/parts.rs new file mode 100644 index 0000000..9c6612b --- /dev/null +++ b/src/path/parts.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; +use std::borrow::Cow; + +use crate::path::DELIMITER_BYTE; + +/// Error returned by [`PathPart::parse`] +#[derive(Debug, thiserror::Error)] +#[error( + "Encountered illegal character sequence \"{}\" whilst parsing path segment \"{}\"", + illegal, + segment +)] +#[allow(missing_copy_implementations)] +pub struct InvalidPart { + segment: String, + illegal: String, +} + +/// The PathPart type exists to validate the directory/file names that form part +/// of a path. +/// +/// A [`PathPart`] is guaranteed to: +/// +/// * Contain no ASCII control characters or `/` +/// * Not be a relative path segment, i.e. `.` or `..` +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] +pub struct PathPart<'a> { + pub(super) raw: Cow<'a, str>, +} + +impl<'a> PathPart<'a> { + /// Parse the provided path segment as a [`PathPart`] returning an error if invalid + pub fn parse(segment: &'a str) -> Result { + if segment == "." || segment == ".." { + return Err(InvalidPart { + segment: segment.to_string(), + illegal: segment.to_string(), + }); + } + + for c in segment.chars() { + if c.is_ascii_control() || c == '/' { + return Err(InvalidPart { + segment: segment.to_string(), + // This is correct as only single byte characters up to this point + illegal: c.to_string(), + }); + } + } + + Ok(Self { + raw: segment.into(), + }) + } +} + +/// Characters we want to encode. 
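To illustrate the `PathPart` rules above: the `From` conversions percent-encode reserved characters, while `PathPart::parse` validates a segment and rejects them instead. A small sketch:

```rust
// Sketch of the PathPart behaviour defined above: From<&str> percent-encodes
// the delimiter and other reserved characters, whereas parse() rejects them.
use object_store::path::PathPart;

fn main() {
    let encoded = PathPart::from("foo/bar#baz");
    assert_eq!(encoded.as_ref(), "foo%2Fbar%23baz");

    // parse() validates rather than encodes: '/' is illegal within one part,
    // but an already percent-encoded part is accepted as-is.
    assert!(PathPart::parse("foo/bar").is_err());
    assert!(PathPart::parse("foo%2Fbar").is_ok());
}
```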
+const INVALID: &AsciiSet = &CONTROLS + // The delimiter we are reserving for internal hierarchy + .add(DELIMITER_BYTE) + // Characters AWS recommends avoiding for object keys + // https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html + .add(b'\\') + .add(b'{') + .add(b'^') + .add(b'}') + .add(b'%') + .add(b'`') + .add(b']') + .add(b'"') // " <-- my editor is confused about double quotes within single quotes + .add(b'>') + .add(b'[') + .add(b'~') + .add(b'<') + .add(b'#') + .add(b'|') + // Characters Google Cloud Storage recommends avoiding for object names + // https://cloud.google.com/storage/docs/naming-objects + .add(b'\r') + .add(b'\n') + .add(b'*') + .add(b'?'); + +impl<'a> From<&'a [u8]> for PathPart<'a> { + fn from(v: &'a [u8]) -> Self { + let inner = match v { + // We don't want to encode `.` generally, but we do want to disallow parts of paths + // to be equal to `.` or `..` to prevent file system traversal shenanigans. + b"." => "%2E".into(), + b".." => "%2E%2E".into(), + other => percent_encode(other, INVALID).into(), + }; + Self { raw: inner } + } +} + +impl<'a> From<&'a str> for PathPart<'a> { + fn from(v: &'a str) -> Self { + Self::from(v.as_bytes()) + } +} + +impl From for PathPart<'static> { + fn from(s: String) -> Self { + Self { + raw: Cow::Owned(PathPart::from(s.as_str()).raw.into_owned()), + } + } +} + +impl AsRef for PathPart<'_> { + fn as_ref(&self) -> &str { + self.raw.as_ref() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn path_part_delimiter_gets_encoded() { + let part: PathPart<'_> = "foo/bar".into(); + assert_eq!(part.raw, "foo%2Fbar"); + } + + #[test] + fn path_part_given_already_encoded_string() { + let part: PathPart<'_> = "foo%2Fbar".into(); + assert_eq!(part.raw, "foo%252Fbar"); + } + + #[test] + fn path_part_cant_be_one_dot() { + let part: PathPart<'_> = ".".into(); + assert_eq!(part.raw, "%2E"); + } + + #[test] + fn path_part_cant_be_two_dots() { + let part: PathPart<'_> = "..".into(); + assert_eq!(part.raw, "%2E%2E"); + } + + #[test] + fn path_part_parse() { + PathPart::parse("foo").unwrap(); + PathPart::parse("foo/bar").unwrap_err(); + + // Test percent-encoded path + PathPart::parse("foo%2Fbar").unwrap(); + PathPart::parse("L%3ABC.parquet").unwrap(); + + // Test path containing bad escape sequence + PathPart::parse("%Z").unwrap(); + PathPart::parse("%%").unwrap(); + } +} diff --git a/src/payload.rs b/src/payload.rs new file mode 100644 index 0000000..055336b --- /dev/null +++ b/src/payload.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use bytes::Bytes; +use std::sync::Arc; + +/// A cheaply cloneable, ordered collection of [`Bytes`] +#[derive(Debug, Clone)] +pub struct PutPayload(Arc<[Bytes]>); + +impl Default for PutPayload { + fn default() -> Self { + Self(Arc::new([])) + } +} + +impl PutPayload { + /// Create a new empty [`PutPayload`] + pub fn new() -> Self { + Self::default() + } + + /// Creates a [`PutPayload`] from a static slice + pub fn from_static(s: &'static [u8]) -> Self { + s.into() + } + + /// Creates a [`PutPayload`] from a [`Bytes`] + pub fn from_bytes(s: Bytes) -> Self { + s.into() + } + + /// Returns the total length of the [`Bytes`] in this payload + pub fn content_length(&self) -> usize { + self.0.iter().map(|b| b.len()).sum() + } + + /// Returns an iterator over the [`Bytes`] in this payload + pub fn iter(&self) -> PutPayloadIter<'_> { + PutPayloadIter(self.0.iter()) + } +} + +impl AsRef<[Bytes]> for PutPayload { + fn as_ref(&self) -> &[Bytes] { + self.0.as_ref() + } +} + +impl<'a> IntoIterator for &'a PutPayload { + type Item = &'a Bytes; + type IntoIter = PutPayloadIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl IntoIterator for PutPayload { + type Item = Bytes; + type IntoIter = PutPayloadIntoIter; + + fn into_iter(self) -> Self::IntoIter { + PutPayloadIntoIter { + payload: self, + idx: 0, + } + } +} + +/// An iterator over [`PutPayload`] +#[derive(Debug)] +pub struct PutPayloadIter<'a>(std::slice::Iter<'a, Bytes>); + +impl<'a> Iterator for PutPayloadIter<'a> { + type Item = &'a Bytes; + + fn next(&mut self) -> Option { + self.0.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +/// An owning iterator of [`PutPayload`] +#[derive(Debug)] +pub struct PutPayloadIntoIter { + payload: PutPayload, + idx: usize, +} + +impl Iterator for PutPayloadIntoIter { + type Item = Bytes; + + fn next(&mut self) -> Option { + let p = self.payload.0.get(self.idx)?.clone(); + self.idx += 1; + Some(p) + } + + fn size_hint(&self) -> (usize, Option) { + let l = self.payload.0.len() - self.idx; + (l, Some(l)) + } +} + +impl From for PutPayload { + fn from(value: Bytes) -> Self { + Self(Arc::new([value])) + } +} + +impl From> for PutPayload { + fn from(value: Vec) -> Self { + Self(Arc::new([value.into()])) + } +} + +impl From<&'static str> for PutPayload { + fn from(value: &'static str) -> Self { + Bytes::from(value).into() + } +} + +impl From<&'static [u8]> for PutPayload { + fn from(value: &'static [u8]) -> Self { + Bytes::from(value).into() + } +} + +impl From for PutPayload { + fn from(value: String) -> Self { + Bytes::from(value).into() + } +} + +impl FromIterator for PutPayload { + fn from_iter>(iter: T) -> Self { + Bytes::from_iter(iter).into() + } +} + +impl FromIterator for PutPayload { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl From for Bytes { + fn from(value: PutPayload) -> Self { + match value.0.len() { + 0 => Self::new(), + 1 => value.0[0].clone(), + _ => { + let mut buf = Vec::with_capacity(value.content_length()); + value.iter().for_each(|x| buf.extend_from_slice(x)); + buf.into() + } + } + } +} + +/// A builder for [`PutPayload`] that avoids reallocating memory +/// +/// Data is allocated in fixed blocks, which are flushed to [`Bytes`] once full. +/// Unlike [`Vec`] this avoids needing to repeatedly reallocate blocks of memory, +/// which typically involves copying all the previously written data to a new +/// contiguous memory region. 
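A short usage sketch of this builder (the block size and data are illustrative only):

```rust
// Usage sketch of the block-based builder described above: extend_from_slice
// copies into fixed-size blocks, while push() appends a Bytes without copying.
use bytes::Bytes;
use object_store::PutPayloadMut;

fn main() {
    let mut builder = PutPayloadMut::new().with_block_size(1024);
    builder.extend_from_slice(b"header,"); // buffered in the in-progress block
    builder.extend_from_slice(b"row1\n");  // still the same block
    builder.push(Bytes::from_static(b"a larger zero-copy chunk")); // no copy

    let payload = builder.freeze();
    assert_eq!(payload.content_length(), 7 + 5 + 24);
}
```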
+#[derive(Debug)] +pub struct PutPayloadMut { + len: usize, + completed: Vec, + in_progress: Vec, + block_size: usize, +} + +impl Default for PutPayloadMut { + fn default() -> Self { + Self { + len: 0, + completed: vec![], + in_progress: vec![], + + block_size: 8 * 1024, + } + } +} + +impl PutPayloadMut { + /// Create a new [`PutPayloadMut`] + pub fn new() -> Self { + Self::default() + } + + /// Configures the minimum allocation size + /// + /// Defaults to 8KB + pub fn with_block_size(self, block_size: usize) -> Self { + Self { block_size, ..self } + } + + /// Write bytes into this [`PutPayloadMut`] + /// + /// If there is an in-progress block, data will be first written to it, flushing + /// it to [`Bytes`] once full. If data remains to be written, a new block of memory + /// of at least the configured block size will be allocated, to hold the remaining data. + pub fn extend_from_slice(&mut self, slice: &[u8]) { + let remaining = self.in_progress.capacity() - self.in_progress.len(); + let to_copy = remaining.min(slice.len()); + + self.in_progress.extend_from_slice(&slice[..to_copy]); + if self.in_progress.capacity() == self.in_progress.len() { + let new_cap = self.block_size.max(slice.len() - to_copy); + let completed = std::mem::replace(&mut self.in_progress, Vec::with_capacity(new_cap)); + if !completed.is_empty() { + self.completed.push(completed.into()) + } + self.in_progress.extend_from_slice(&slice[to_copy..]) + } + self.len += slice.len(); + } + + /// Append a [`Bytes`] to this [`PutPayloadMut`] without copying + /// + /// This will close any currently buffered block populated by [`Self::extend_from_slice`], + /// and append `bytes` to this payload without copying. + pub fn push(&mut self, bytes: Bytes) { + if !self.in_progress.is_empty() { + let completed = std::mem::take(&mut self.in_progress); + self.completed.push(completed.into()) + } + self.len += bytes.len(); + self.completed.push(bytes); + } + + /// Returns `true` if this [`PutPayloadMut`] contains no bytes + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the total length of the [`Bytes`] in this payload + #[inline] + pub fn content_length(&self) -> usize { + self.len + } + + /// Convert into [`PutPayload`] + pub fn freeze(mut self) -> PutPayload { + if !self.in_progress.is_empty() { + let completed = std::mem::take(&mut self.in_progress).into(); + self.completed.push(completed); + } + PutPayload(self.completed.into()) + } +} + +impl From for PutPayload { + fn from(value: PutPayloadMut) -> Self { + value.freeze() + } +} + +#[cfg(test)] +mod test { + use crate::PutPayloadMut; + + #[test] + fn test_put_payload() { + let mut chunk = PutPayloadMut::new().with_block_size(23); + chunk.extend_from_slice(&[1; 16]); + chunk.extend_from_slice(&[2; 32]); + chunk.extend_from_slice(&[2; 5]); + chunk.extend_from_slice(&[2; 21]); + chunk.extend_from_slice(&[2; 40]); + chunk.extend_from_slice(&[0; 0]); + chunk.push("foobar".into()); + + let payload = chunk.freeze(); + assert_eq!(payload.content_length(), 120); + + let chunks = payload.as_ref(); + assert_eq!(chunks.len(), 6); + + assert_eq!(chunks[0].len(), 23); + assert_eq!(chunks[1].len(), 25); // 32 - (23 - 16) + assert_eq!(chunks[2].len(), 23); + assert_eq!(chunks[3].len(), 23); + assert_eq!(chunks[4].len(), 20); + assert_eq!(chunks[5].len(), 6); + } + + #[test] + fn test_content_length() { + let mut chunk = PutPayloadMut::new(); + chunk.push(vec![0; 23].into()); + assert_eq!(chunk.content_length(), 23); + chunk.extend_from_slice(&[0; 4]); + 
assert_eq!(chunk.content_length(), 27); + chunk.push(vec![0; 121].into()); + assert_eq!(chunk.content_length(), 148); + let payload = chunk.freeze(); + assert_eq!(payload.content_length(), 148); + } +} diff --git a/src/prefix.rs b/src/prefix.rs new file mode 100644 index 0000000..ac9803e --- /dev/null +++ b/src/prefix.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store wrapper handling a constant path prefix +use bytes::Bytes; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use std::ops::Range; + +use crate::path::Path; +use crate::{ + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOpts, + PutOptions, PutPayload, PutResult, Result, +}; + +/// Store wrapper that applies a constant prefix to all paths handled by the store. +#[derive(Debug, Clone)] +pub struct PrefixStore { + prefix: Path, + inner: T, +} + +impl std::fmt::Display for PrefixStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "PrefixObjectStore({})", self.prefix.as_ref()) + } +} + +impl PrefixStore { + /// Create a new instance of [`PrefixStore`] + pub fn new(store: T, prefix: impl Into) -> Self { + Self { + prefix: prefix.into(), + inner: store, + } + } + + /// Create the full path from a path relative to prefix + fn full_path(&self, location: &Path) -> Path { + self.prefix.parts().chain(location.parts()).collect() + } + + /// Strip the constant prefix from a given path + fn strip_prefix(&self, path: Path) -> Path { + // Note cannot use match because of borrow checker + if let Some(suffix) = path.prefix_match(&self.prefix) { + return suffix.collect(); + } + path + } + + /// Strip the constant prefix from a given ObjectMeta + fn strip_meta(&self, meta: ObjectMeta) -> ObjectMeta { + ObjectMeta { + last_modified: meta.last_modified, + size: meta.size, + location: self.strip_prefix(meta.location), + e_tag: meta.e_tag, + version: None, + } + } +} + +// Note: This is a relative hack to move these two functions to pure functions so they don't rely +// on the `self` lifetime. Expected to be cleaned up before merge. 
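A small sketch of how the wrapper is used (the in-memory store and prefix name are illustrative, and a Tokio runtime is assumed):

```rust
// Sketch of PrefixStore: callers use paths relative to the prefix, while the
// wrapped store sees them under "tenant-a/".
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::prefix::PrefixStore;
use object_store::ObjectStore;

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let prefixed = PrefixStore::new(InMemory::new(), "tenant-a");

    prefixed.put(&Path::from("data.bin"), "payload".into()).await?;

    // Reads go through the same relative path; the inner store actually holds
    // the object at "tenant-a/data.bin".
    let bytes = prefixed.get(&Path::from("data.bin")).await?.bytes().await?;
    assert_eq!(bytes.as_ref(), b"payload");
    Ok(())
}
```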
+// +/// Strip the constant prefix from a given path +fn strip_prefix(prefix: &Path, path: Path) -> Path { + // Note cannot use match because of borrow checker + if let Some(suffix) = path.prefix_match(prefix) { + return suffix.collect(); + } + path +} + +/// Strip the constant prefix from a given ObjectMeta +fn strip_meta(prefix: &Path, meta: ObjectMeta) -> ObjectMeta { + ObjectMeta { + last_modified: meta.last_modified, + size: meta.size, + location: strip_prefix(prefix, meta.location), + e_tag: meta.e_tag, + version: None, + } +} +#[async_trait::async_trait] +impl ObjectStore for PrefixStore { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + let full_path = self.full_path(location); + self.inner.put(&full_path, payload).await + } + + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let full_path = self.full_path(location); + self.inner.put_opts(&full_path, payload, opts).await + } + + async fn put_multipart(&self, location: &Path) -> Result> { + let full_path = self.full_path(location); + self.inner.put_multipart(&full_path).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + let full_path = self.full_path(location); + self.inner.put_multipart_opts(&full_path, opts).await + } + + async fn get(&self, location: &Path) -> Result { + let full_path = self.full_path(location); + self.inner.get(&full_path).await + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let full_path = self.full_path(location); + self.inner.get_range(&full_path, range).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + let full_path = self.full_path(location); + self.inner.get_opts(&full_path, options).await + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + let full_path = self.full_path(location); + self.inner.get_ranges(&full_path, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + let full_path = self.full_path(location); + let meta = self.inner.head(&full_path).await?; + Ok(self.strip_meta(meta)) + } + + async fn delete(&self, location: &Path) -> Result<()> { + let full_path = self.full_path(location); + self.inner.delete(&full_path).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let prefix = self.full_path(prefix.unwrap_or(&Path::default())); + let s = self.inner.list(Some(&prefix)); + let slf_prefix = self.prefix.clone(); + s.map_ok(move |meta| strip_meta(&slf_prefix, meta)).boxed() + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + let offset = self.full_path(offset); + let prefix = self.full_path(prefix.unwrap_or(&Path::default())); + let s = self.inner.list_with_offset(Some(&prefix), &offset); + let slf_prefix = self.prefix.clone(); + s.map_ok(move |meta| strip_meta(&slf_prefix, meta)).boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let prefix = self.full_path(prefix.unwrap_or(&Path::default())); + self.inner + .list_with_delimiter(Some(&prefix)) + .await + .map(|lst| ListResult { + common_prefixes: lst + .common_prefixes + .into_iter() + .map(|p| self.strip_prefix(p)) + .collect(), + objects: lst + .objects + .into_iter() + .map(|meta| self.strip_meta(meta)) + .collect(), + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let full_from = self.full_path(from); + let 
full_to = self.full_path(to); + self.inner.copy(&full_from, &full_to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let full_from = self.full_path(from); + let full_to = self.full_path(to); + self.inner.rename(&full_from, &full_to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let full_from = self.full_path(from); + let full_to = self.full_path(to); + self.inner.copy_if_not_exists(&full_from, &full_to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let full_from = self.full_path(from); + let full_to = self.full_path(to); + self.inner.rename_if_not_exists(&full_from, &full_to).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::integration::*; + use crate::local::LocalFileSystem; + + use tempfile::TempDir; + + #[tokio::test] + async fn prefix_test() { + let root = TempDir::new().unwrap(); + let inner = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let integration = PrefixStore::new(inner, "prefix"); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn prefix_test_applies_prefix() { + let tmpdir = TempDir::new().unwrap(); + let local = LocalFileSystem::new_with_prefix(tmpdir.path()).unwrap(); + + let location = Path::from("prefix/test_file.json"); + let data = Bytes::from("arbitrary data"); + + local.put(&location, data.clone().into()).await.unwrap(); + + let prefix = PrefixStore::new(local, "prefix"); + let location_prefix = Path::from("test_file.json"); + + let content_list = flatten_list_stream(&prefix, None).await.unwrap(); + assert_eq!(content_list, &[location_prefix.clone()]); + + let root = Path::from("/"); + let content_list = flatten_list_stream(&prefix, Some(&root)).await.unwrap(); + assert_eq!(content_list, &[location_prefix.clone()]); + + let read_data = prefix + .get(&location_prefix) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, data); + + let target_prefix = Path::from("/test_written.json"); + prefix + .put(&target_prefix, data.clone().into()) + .await + .unwrap(); + + prefix.delete(&location_prefix).await.unwrap(); + + let local = LocalFileSystem::new_with_prefix(tmpdir.path()).unwrap(); + + let err = local.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + + let location = Path::from("prefix/test_written.json"); + let read_data = local.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(&*read_data, data) + } +} diff --git a/src/signer.rs b/src/signer.rs new file mode 100644 index 0000000..da55c68 --- /dev/null +++ b/src/signer.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Abstraction of signed URL generation for those object store implementations that support it + +use crate::{path::Path, Result}; +use async_trait::async_trait; +use reqwest::Method; +use std::{fmt, time::Duration}; +use url::Url; + +/// Universal API to generate presigned URLs from multiple object store services. +#[async_trait] +pub trait Signer: Send + Sync + fmt::Debug + 'static { + /// Given the intended [`Method`] and [`Path`] to use and the desired length of time for which + /// the URL should be valid, return a signed [`Url`] created with the object store + /// implementation's credentials such that the URL can be handed to something that doesn't have + /// access to the object store's credentials, to allow limited access to the object store. + async fn signed_url(&self, method: Method, path: &Path, expires_in: Duration) -> Result; + + /// Generate signed urls for multiple paths. + /// + /// See [`Signer::signed_url`] for more details. + async fn signed_urls( + &self, + method: Method, + paths: &[Path], + expires_in: Duration, + ) -> Result> { + let mut urls = Vec::with_capacity(paths.len()); + for path in paths { + urls.push(self.signed_url(method.clone(), path, expires_in).await?); + } + Ok(urls) + } +} diff --git a/src/tags.rs b/src/tags.rs new file mode 100644 index 0000000..fa6e591 --- /dev/null +++ b/src/tags.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use url::form_urlencoded::Serializer; + +/// A collection of key value pairs used to annotate objects +/// +/// +/// +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub struct TagSet(String); + +impl TagSet { + /// Append a key value pair to this [`TagSet`] + /// + /// Stores have different restrictions on what characters are permitted, + /// for portability it is recommended applications use no more than 10 tags, + /// and stick to alphanumeric characters, and `+ - = . 
_ : /` + /// + /// + /// + pub fn push(&mut self, key: &str, value: &str) { + Serializer::new(&mut self.0).append_pair(key, value); + } + + /// Return this [`TagSet`] as a URL-encoded string + pub fn encoded(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tag_set() { + let mut set = TagSet::default(); + set.push("test/foo", "value sdlks"); + set.push("foo", " sdf _ /+./sd"); + assert_eq!( + set.encoded(), + "test%2Ffoo=value+sdlks&foo=+sdf+_+%2F%2B.%2Fsd" + ); + } +} diff --git a/src/throttle.rs b/src/throttle.rs new file mode 100644 index 0000000..6586ba9 --- /dev/null +++ b/src/throttle.rs @@ -0,0 +1,658 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A throttling object store wrapper +use parking_lot::Mutex; +use std::ops::Range; +use std::{convert::TryInto, sync::Arc}; + +use crate::multipart::{MultipartStore, PartId}; +use crate::{ + path::Path, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, +}; +use crate::{GetOptions, UploadPart}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::{stream::BoxStream, FutureExt, StreamExt}; +use std::time::Duration; + +/// Configuration settings for throttled store +#[derive(Debug, Default, Clone, Copy)] +pub struct ThrottleConfig { + /// Sleep duration for every call to [`delete`](ThrottledStore::delete). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. + pub wait_delete_per_call: Duration, + + /// Sleep duration for every byte received during [`get`](ThrottledStore::get). + /// + /// Sleeping is performed after the underlying store returned and only for successful gets. The + /// sleep duration is additive to [`wait_get_per_call`](Self::wait_get_per_call). + /// + /// Note that the per-byte sleep only happens as the user consumes the output bytes. Should + /// there be an intermediate failure (i.e. after partly consuming the output bytes), the + /// resulting sleep time will be partial as well. + pub wait_get_per_byte: Duration, + + /// Sleep duration for every call to [`get`](ThrottledStore::get). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. The sleep duration is additive to + /// [`wait_get_per_byte`](Self::wait_get_per_byte). + pub wait_get_per_call: Duration, + + /// Sleep duration for every call to [`list`](ThrottledStore::list). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. The sleep duration is additive to + /// [`wait_list_per_entry`](Self::wait_list_per_entry). 
+ pub wait_list_per_call: Duration, + + /// Sleep duration for every entry received during [`list`](ThrottledStore::list). + /// + /// Sleeping is performed after the underlying store returned and only for successful lists. + /// The sleep duration is additive to [`wait_list_per_call`](Self::wait_list_per_call). + /// + /// Note that the per-entry sleep only happens as the user consumes the output entries. Should + /// there be an intermediate failure (i.e. after partly consuming the output entries), the + /// resulting sleep time will be partial as well. + pub wait_list_per_entry: Duration, + + /// Sleep duration for every call to + /// [`list_with_delimiter`](ThrottledStore::list_with_delimiter). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. The sleep duration is additive to + /// [`wait_list_with_delimiter_per_entry`](Self::wait_list_with_delimiter_per_entry). + pub wait_list_with_delimiter_per_call: Duration, + + /// Sleep duration for every entry received during + /// [`list_with_delimiter`](ThrottledStore::list_with_delimiter). + /// + /// Sleeping is performed after the underlying store returned and only for successful gets. The + /// sleep duration is additive to + /// [`wait_list_with_delimiter_per_call`](Self::wait_list_with_delimiter_per_call). + pub wait_list_with_delimiter_per_entry: Duration, + + /// Sleep duration for every call to [`put`](ThrottledStore::put). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. + pub wait_put_per_call: Duration, +} + +/// Sleep only if non-zero duration +async fn sleep(duration: Duration) { + if !duration.is_zero() { + tokio::time::sleep(duration).await + } +} + +/// Store wrapper that wraps an inner store with some `sleep` calls. +/// +/// This can be used for performance testing. +/// +/// **Note that the behavior of the wrapper is deterministic and might not reflect real-world +/// conditions!** +#[derive(Debug)] +pub struct ThrottledStore { + inner: T, + config: Arc>, +} + +impl ThrottledStore { + /// Create new wrapper with zero waiting times. + pub fn new(inner: T, config: ThrottleConfig) -> Self { + Self { + inner, + config: Arc::new(Mutex::new(config)), + } + } + + /// Mutate config. + pub fn config_mut(&self, f: F) + where + F: Fn(&mut ThrottleConfig), + { + let mut guard = self.config.lock(); + f(&mut guard) + } + + /// Return copy of current config. 
+ pub fn config(&self) -> ThrottleConfig { + *self.config.lock() + } +} + +impl std::fmt::Display for ThrottledStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ThrottledStore({})", self.inner) + } +} + +#[async_trait] +impl ObjectStore for ThrottledStore { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put(location, payload).await + } + + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart(&self, location: &Path) -> Result> { + let upload = self.inner.put_multipart(location).await?; + Ok(Box::new(ThrottledUpload { + upload, + sleep: self.config().wait_put_per_call, + })) + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOpts, + ) -> Result> { + let upload = self.inner.put_multipart_opts(location, opts).await?; + Ok(Box::new(ThrottledUpload { + upload, + sleep: self.config().wait_put_per_call, + })) + } + + async fn get(&self, location: &Path) -> Result { + sleep(self.config().wait_get_per_call).await; + + // need to copy to avoid moving / referencing `self` + let wait_get_per_byte = self.config().wait_get_per_byte; + + let result = self.inner.get(location).await?; + Ok(throttle_get(result, wait_get_per_byte)) + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + sleep(self.config().wait_get_per_call).await; + + // need to copy to avoid moving / referencing `self` + let wait_get_per_byte = self.config().wait_get_per_byte; + + let result = self.inner.get_opts(location, options).await?; + Ok(throttle_get(result, wait_get_per_byte)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let config = self.config(); + + let sleep_duration = + config.wait_get_per_call + config.wait_get_per_byte * (range.end - range.start) as u32; + + sleep(sleep_duration).await; + + self.inner.get_range(location, range).await + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> Result> { + let config = self.config(); + + let total_bytes: u64 = ranges.iter().map(|range| range.end - range.start).sum(); + let sleep_duration = + config.wait_get_per_call + config.wait_get_per_byte * total_bytes as u32; + + sleep(sleep_duration).await; + + self.inner.get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + sleep(self.config().wait_delete_per_call).await; + + self.inner.delete(location).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let stream = self.inner.list(prefix); + let config = Arc::clone(&self.config); + futures::stream::once(async move { + let config = *config.lock(); + let wait_list_per_entry = config.wait_list_per_entry; + sleep(config.wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, Result> { + let stream = self.inner.list_with_offset(prefix, offset); + let config = Arc::clone(&self.config); + futures::stream::once(async move { + let config = *config.lock(); + let wait_list_per_entry = 
config.wait_list_per_entry; + sleep(config.wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + sleep(self.config().wait_list_with_delimiter_per_call).await; + + match self.inner.list_with_delimiter(prefix).await { + Ok(list_result) => { + let entries_len = usize_to_u32_saturate(list_result.objects.len()); + sleep(self.config().wait_list_with_delimiter_per_entry * entries_len).await; + Ok(list_result) + } + Err(err) => Err(err), + } + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.rename(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.copy_if_not_exists(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.rename_if_not_exists(from, to).await + } +} + +/// Saturated `usize` to `u32` cast. +fn usize_to_u32_saturate(x: usize) -> u32 { + x.try_into().unwrap_or(u32::MAX) +} + +fn throttle_get(result: GetResult, wait_get_per_byte: Duration) -> GetResult { + #[allow(clippy::infallible_destructuring_match)] + let s = match result.payload { + GetResultPayload::Stream(s) => s, + #[cfg(all(feature = "fs", not(target_arch = "wasm32")))] + GetResultPayload::File(_, _) => unimplemented!(), + }; + + let stream = throttle_stream(s, move |bytes| { + let bytes_len: u32 = usize_to_u32_saturate(bytes.len()); + wait_get_per_byte * bytes_len + }); + + GetResult { + payload: GetResultPayload::Stream(stream), + ..result + } +} + +fn throttle_stream( + stream: BoxStream<'_, Result>, + delay: F, +) -> BoxStream<'_, Result> +where + F: Fn(&T) -> Duration + Send + Sync + 'static, +{ + stream + .then(move |result| { + let delay = result.as_ref().ok().map(&delay).unwrap_or_default(); + sleep(delay).then(|_| futures::future::ready(result)) + }) + .boxed() +} + +#[async_trait] +impl MultipartStore for ThrottledStore { + async fn create_multipart(&self, path: &Path) -> Result { + self.inner.create_multipart(path).await + } + + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + data: PutPayload, + ) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put_part(path, id, part_idx, data).await + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result { + self.inner.complete_multipart(path, id, parts).await + } + + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> { + self.inner.abort_multipart(path, id).await + } +} + +#[derive(Debug)] +struct ThrottledUpload { + upload: Box, + sleep: Duration, +} + +#[async_trait] +impl MultipartUpload for ThrottledUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let duration = self.sleep; + let put = self.upload.put_part(data); + Box::pin(async move { + sleep(duration).await; + put.await + }) + } + + async fn complete(&mut self) -> Result { + self.upload.complete().await + } + + async fn abort(&mut self) -> Result<()> { + self.upload.abort().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{integration::*, memory::InMemory, 
GetResultPayload}; + use futures::TryStreamExt; + use tokio::time::Duration; + use tokio::time::Instant; + + const WAIT_TIME: Duration = Duration::from_millis(100); + const ZERO: Duration = Duration::from_millis(0); // Duration::default isn't constant + + macro_rules! assert_bounds { + ($d:expr, $lower:expr) => { + assert_bounds!($d, $lower, $lower + 2); + }; + ($d:expr, $lower:expr, $upper:expr) => { + let d = $d; + let lower = $lower * WAIT_TIME; + let upper = $upper * WAIT_TIME; + assert!(d >= lower, "{:?} must be >= than {:?}", d, lower); + assert!(d < upper, "{:?} must be < than {:?}", d, upper); + }; + } + + #[tokio::test] + async fn throttle_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + put_get_delete_list(&store).await; + list_uses_directories_correctly(&store).await; + list_with_delimiter(&store).await; + rename_and_copy(&store).await; + copy_if_not_exists(&store).await; + stream_get(&store).await; + multipart(&store, &store).await; + } + + #[tokio::test] + async fn delete_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_delete(&store, None).await, 0); + assert_bounds!(measure_delete(&store, Some(0)).await, 0); + assert_bounds!(measure_delete(&store, Some(10)).await, 0); + + store.config_mut(|cfg| cfg.wait_delete_per_call = WAIT_TIME); + assert_bounds!(measure_delete(&store, None).await, 1); + assert_bounds!(measure_delete(&store, Some(0)).await, 1); + assert_bounds!(measure_delete(&store, Some(10)).await, 1); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + async fn get_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_get(&store, None).await, 0); + assert_bounds!(measure_get(&store, Some(0)).await, 0); + assert_bounds!(measure_get(&store, Some(10)).await, 0); + + store.config_mut(|cfg| cfg.wait_get_per_call = WAIT_TIME); + assert_bounds!(measure_get(&store, None).await, 1); + assert_bounds!(measure_get(&store, Some(0)).await, 1); + assert_bounds!(measure_get(&store, Some(10)).await, 1); + + store.config_mut(|cfg| { + cfg.wait_get_per_call = ZERO; + cfg.wait_get_per_byte = WAIT_TIME; + }); + assert_bounds!(measure_get(&store, Some(2)).await, 2); + + store.config_mut(|cfg| { + cfg.wait_get_per_call = WAIT_TIME; + cfg.wait_get_per_byte = WAIT_TIME; + }); + assert_bounds!(measure_get(&store, Some(2)).await, 3); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + async fn list_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_list(&store, 0).await, 0); + assert_bounds!(measure_list(&store, 10).await, 0); + + store.config_mut(|cfg| cfg.wait_list_per_call = WAIT_TIME); + assert_bounds!(measure_list(&store, 0).await, 1); + assert_bounds!(measure_list(&store, 10).await, 1); + + store.config_mut(|cfg| { + cfg.wait_list_per_call = ZERO; + cfg.wait_list_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list(&store, 2).await, 2); + + store.config_mut(|cfg| { + cfg.wait_list_per_call = WAIT_TIME; + cfg.wait_list_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list(&store, 2).await, 3); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + 
async fn list_with_delimiter_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_list_with_delimiter(&store, 0).await, 0); + assert_bounds!(measure_list_with_delimiter(&store, 10).await, 0); + + store.config_mut(|cfg| cfg.wait_list_with_delimiter_per_call = WAIT_TIME); + assert_bounds!(measure_list_with_delimiter(&store, 0).await, 1); + assert_bounds!(measure_list_with_delimiter(&store, 10).await, 1); + + store.config_mut(|cfg| { + cfg.wait_list_with_delimiter_per_call = ZERO; + cfg.wait_list_with_delimiter_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list_with_delimiter(&store, 2).await, 2); + + store.config_mut(|cfg| { + cfg.wait_list_with_delimiter_per_call = WAIT_TIME; + cfg.wait_list_with_delimiter_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list_with_delimiter(&store, 2).await, 3); + } + + #[tokio::test] + async fn put_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_put(&store, 0).await, 0); + assert_bounds!(measure_put(&store, 10).await, 0); + + store.config_mut(|cfg| cfg.wait_put_per_call = WAIT_TIME); + assert_bounds!(measure_put(&store, 0).await, 1); + assert_bounds!(measure_put(&store, 10).await, 1); + + store.config_mut(|cfg| cfg.wait_put_per_call = ZERO); + assert_bounds!(measure_put(&store, 0).await, 0); + } + + async fn place_test_object(store: &ThrottledStore, n_bytes: Option) -> Path { + let path = Path::from("foo"); + + if let Some(n_bytes) = n_bytes { + let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect(); + store.put(&path, data.into()).await.unwrap(); + } else { + // ensure object is absent + store.delete(&path).await.unwrap(); + } + + path + } + + #[allow(dead_code)] + async fn place_test_objects(store: &ThrottledStore, n_entries: usize) -> Path { + let prefix = Path::from("foo"); + + // clean up store + let entries: Vec<_> = store.list(Some(&prefix)).try_collect().await.unwrap(); + + for entry in entries { + store.delete(&entry.location).await.unwrap(); + } + + // create new entries + for i in 0..n_entries { + let path = prefix.child(i.to_string().as_str()); + store.put(&path, "bar".into()).await.unwrap(); + } + + prefix + } + + async fn measure_delete(store: &ThrottledStore, n_bytes: Option) -> Duration { + let path = place_test_object(store, n_bytes).await; + + let t0 = Instant::now(); + store.delete(&path).await.unwrap(); + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_get(store: &ThrottledStore, n_bytes: Option) -> Duration { + let path = place_test_object(store, n_bytes).await; + + let t0 = Instant::now(); + let res = store.get(&path).await; + if n_bytes.is_some() { + // need to consume bytes to provoke sleep times + let s = match res.unwrap().payload { + GetResultPayload::Stream(s) => s, + GetResultPayload::File(_, _) => unimplemented!(), + }; + + s.map_ok(|b| bytes::BytesMut::from(&b[..])) + .try_concat() + .await + .unwrap(); + } else { + assert!(res.is_err()); + } + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_list(store: &ThrottledStore, n_entries: usize) -> Duration { + let prefix = place_test_objects(store, n_entries).await; + + let t0 = Instant::now(); + store + .list(Some(&prefix)) + .try_collect::>() + .await + .unwrap(); + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_list_with_delimiter( + store: &ThrottledStore, + n_entries: usize, + ) -> Duration { + let prefix = place_test_objects(store, 
n_entries).await; + + let t0 = Instant::now(); + store.list_with_delimiter(Some(&prefix)).await.unwrap(); + + t0.elapsed() + } + + async fn measure_put(store: &ThrottledStore, n_bytes: usize) -> Duration { + let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect(); + + let t0 = Instant::now(); + store.put(&Path::from("foo"), data.into()).await.unwrap(); + + t0.elapsed() + } +} diff --git a/src/upload.rs b/src/upload.rs new file mode 100644 index 0000000..4df4d8f --- /dev/null +++ b/src/upload.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::task::{Context, Poll}; + +use crate::{PutPayload, PutPayloadMut, PutResult, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::ready; +use tokio::task::JoinSet; + +/// An upload part request +pub type UploadPart = BoxFuture<'static, Result<()>>; + +/// A trait allowing writing an object in fixed size chunks +/// +/// Consecutive chunks of data can be written by calling [`MultipartUpload::put_part`] and polling +/// the returned futures to completion. Multiple futures returned by [`MultipartUpload::put_part`] +/// may be polled in parallel, allowing for concurrent uploads. +/// +/// Once all part uploads have been polled to completion, the upload can be completed by +/// calling [`MultipartUpload::complete`]. This will make the entire uploaded object visible +/// as an atomic operation. It is implementation defined behaviour if [`MultipartUpload::complete`] +/// is called before all [`UploadPart`] have been polled to completion. +#[async_trait] +pub trait MultipartUpload: Send + std::fmt::Debug { + /// Upload the next part + /// + /// Most stores require that all parts excluding the last are at least 5 MiB, and some + /// further require that all parts excluding the last be the same size, e.g. [R2]. + /// Clients wanting to maximise compatibility should therefore perform writes in + /// fixed size blocks larger than 5 MiB. 
+ /// + /// Implementations may invoke this method multiple times and then await on the + /// returned futures in parallel + /// + /// ```no_run + /// # use futures::StreamExt; + /// # use object_store::MultipartUpload; + /// # + /// # async fn test() { + /// # + /// let mut upload: Box<&dyn MultipartUpload> = todo!(); + /// let p1 = upload.put_part(vec![0; 10 * 1024 * 1024].into()); + /// let p2 = upload.put_part(vec![1; 10 * 1024 * 1024].into()); + /// futures::future::try_join(p1, p2).await.unwrap(); + /// upload.complete().await.unwrap(); + /// # } + /// ``` + /// + /// [R2]: https://developers.cloudflare.com/r2/objects/multipart-objects/#limitations + fn put_part(&mut self, data: PutPayload) -> UploadPart; + + /// Complete the multipart upload + /// + /// It is implementation defined behaviour if this method is called before polling + /// all [`UploadPart`] returned by [`MultipartUpload::put_part`] to completion. Additionally, + /// it is implementation defined behaviour to call [`MultipartUpload::complete`] + /// on an already completed or aborted [`MultipartUpload`]. + async fn complete(&mut self) -> Result; + + /// Abort the multipart upload + /// + /// If a [`MultipartUpload`] is dropped without calling [`MultipartUpload::complete`], + /// some object stores will automatically clean up any previously uploaded parts. + /// However, some stores, such as S3 and GCS, cannot perform cleanup on drop. + /// As such [`MultipartUpload::abort`] can be invoked to perform this cleanup. + /// + /// It will not be possible to call `abort` in all failure scenarios, for example + /// non-graceful shutdown of the calling application. It is therefore recommended + /// object stores are configured with lifecycle rules to automatically cleanup + /// unused parts older than some threshold. See [crate::aws] and [crate::gcp] + /// for more information. + /// + /// It is implementation defined behaviour to call [`MultipartUpload::abort`] + /// on an already completed or aborted [`MultipartUpload`] + async fn abort(&mut self) -> Result<()>; +} + +#[async_trait] +impl MultipartUpload for Box { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + (**self).put_part(data) + } + + async fn complete(&mut self) -> Result { + (**self).complete().await + } + + async fn abort(&mut self) -> Result<()> { + (**self).abort().await + } +} + +/// A synchronous write API for uploading data in parallel in fixed size chunks +/// +/// Uses multiple tokio tasks in a [`JoinSet`] to multiplex upload tasks in parallel +/// +/// The design also takes inspiration from [`Sink`] with [`WriteMultipart::wait_for_capacity`] +/// allowing back pressure on producers, prior to buffering the next part. 
However, unlike +/// [`Sink`] this back pressure is optional, allowing integration with synchronous producers +/// +/// [`Sink`]: futures::sink::Sink +#[derive(Debug)] +pub struct WriteMultipart { + upload: Box, + + buffer: PutPayloadMut, + + chunk_size: usize, + + tasks: JoinSet>, +} + +impl WriteMultipart { + /// Create a new [`WriteMultipart`] that will upload using 5MB chunks + pub fn new(upload: Box) -> Self { + Self::new_with_chunk_size(upload, 5 * 1024 * 1024) + } + + /// Create a new [`WriteMultipart`] that will upload in fixed `chunk_size` sized chunks + pub fn new_with_chunk_size(upload: Box, chunk_size: usize) -> Self { + Self { + upload, + chunk_size, + buffer: PutPayloadMut::new(), + tasks: Default::default(), + } + } + + /// Polls for there to be less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::wait_for_capacity`] for an async version of this function + pub fn poll_for_capacity( + &mut self, + cx: &mut Context<'_>, + max_concurrency: usize, + ) -> Poll> { + while !self.tasks.is_empty() && self.tasks.len() >= max_concurrency { + ready!(self.tasks.poll_join_next(cx)).unwrap()?? + } + Poll::Ready(Ok(())) + } + + /// Wait until there are less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::poll_for_capacity`] for a [`Poll`] version of this function + pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> { + futures::future::poll_fn(|cx| self.poll_for_capacity(cx, max_concurrency)).await + } + + /// Write data to this [`WriteMultipart`] + /// + /// Data is buffered using [`PutPayloadMut::extend_from_slice`]. Implementations looking to + /// write data from owned buffers may prefer [`Self::put`] as this avoids copying. + /// + /// Note this method is synchronous (not `async`) and will immediately + /// start new uploads as soon as the internal `chunk_size` is hit, + /// regardless of how many outstanding uploads are already in progress. + /// + /// Back pressure can optionally be applied to producers by calling + /// [`Self::wait_for_capacity`] prior to calling this method + pub fn write(&mut self, mut buf: &[u8]) { + while !buf.is_empty() { + let remaining = self.chunk_size - self.buffer.content_length(); + let to_read = buf.len().min(remaining); + self.buffer.extend_from_slice(&buf[..to_read]); + if to_read == remaining { + let buffer = std::mem::take(&mut self.buffer); + self.put_part(buffer.into()) + } + buf = &buf[to_read..] + } + } + + /// Put a chunk of data into this [`WriteMultipart`] without copying + /// + /// Data is buffered using [`PutPayloadMut::push`]. Implementations looking to + /// perform writes from non-owned buffers should prefer [`Self::write`] as this + /// will allow multiple calls to share the same underlying allocation. 
+ /// + /// See [`Self::write`] for information on backpressure + pub fn put(&mut self, mut bytes: Bytes) { + while !bytes.is_empty() { + let remaining = self.chunk_size - self.buffer.content_length(); + if bytes.len() < remaining { + self.buffer.push(bytes); + return; + } + self.buffer.push(bytes.split_to(remaining)); + let buffer = std::mem::take(&mut self.buffer); + self.put_part(buffer.into()) + } + } + + pub(crate) fn put_part(&mut self, part: PutPayload) { + self.tasks.spawn(self.upload.put_part(part)); + } + + /// Abort this upload, attempting to clean up any successfully uploaded parts + pub async fn abort(mut self) -> Result<()> { + self.tasks.shutdown().await; + self.upload.abort().await + } + + /// Flush final chunk, and await completion of all in-flight requests + pub async fn finish(mut self) -> Result { + if !self.buffer.is_empty() { + let part = std::mem::take(&mut self.buffer); + self.put_part(part.into()) + } + + self.wait_for_capacity(0).await?; + + match self.upload.complete().await { + Err(e) => { + self.tasks.shutdown().await; + self.upload.abort().await?; + Err(e) + } + Ok(result) => Ok(result), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use futures::FutureExt; + use parking_lot::Mutex; + use rand::prelude::StdRng; + use rand::{Rng, SeedableRng}; + + use crate::memory::InMemory; + use crate::path::Path; + use crate::throttle::{ThrottleConfig, ThrottledStore}; + use crate::ObjectStore; + + use super::*; + + #[tokio::test] + async fn test_concurrency() { + let config = ThrottleConfig { + wait_put_per_call: Duration::from_millis(1), + ..Default::default() + }; + + let path = Path::from("foo"); + let store = ThrottledStore::new(InMemory::new(), config); + let upload = store.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new_with_chunk_size(upload, 10); + + for _ in 0..20 { + write.write(&[0; 5]); + } + assert!(write.wait_for_capacity(10).now_or_never().is_none()); + write.wait_for_capacity(10).await.unwrap() + } + + #[derive(Debug, Default)] + struct InstrumentedUpload { + chunks: Arc>>, + } + + #[async_trait] + impl MultipartUpload for InstrumentedUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + self.chunks.lock().push(data); + futures::future::ready(Ok(())).boxed() + } + + async fn complete(&mut self) -> Result { + Ok(PutResult { + e_tag: None, + version: None, + }) + } + + async fn abort(&mut self) -> Result<()> { + unimplemented!() + } + } + + #[tokio::test] + async fn test_write_multipart() { + let mut rng = StdRng::seed_from_u64(42); + + for method in [0.0, 0.5, 1.0] { + for _ in 0..10 { + for chunk_size in [1, 17, 23] { + let upload = Box::::default(); + let chunks = Arc::clone(&upload.chunks); + let mut write = WriteMultipart::new_with_chunk_size(upload, chunk_size); + + let mut expected = Vec::with_capacity(1024); + + for _ in 0..50 { + let chunk_size = rng.gen_range(0..30); + let data: Vec<_> = (0..chunk_size).map(|_| rng.gen()).collect(); + expected.extend_from_slice(&data); + + match rng.gen_bool(method) { + true => write.put(data.into()), + false => write.write(&data), + } + } + write.finish().await.unwrap(); + + let chunks = chunks.lock(); + + let actual: Vec<_> = chunks.iter().flatten().flatten().copied().collect(); + assert_eq!(expected, actual); + + for chunk in chunks.iter().take(chunks.len() - 1) { + assert_eq!(chunk.content_length(), chunk_size) + } + + let last_chunk = chunks.last().unwrap().content_length(); + assert!(last_chunk <= chunk_size, 
"{chunk_size}"); + } + } + } + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..17a7a8c --- /dev/null +++ b/src/util.rs @@ -0,0 +1,489 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common logic for interacting with remote object stores +use std::{ + fmt::Display, + ops::{Range, RangeBounds}, +}; + +use super::Result; +use bytes::Bytes; +use futures::{stream::StreamExt, Stream, TryStreamExt}; + +#[cfg(any(feature = "azure", feature = "http"))] +pub(crate) static RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; + +// deserialize dates according to rfc1123 +#[cfg(any(feature = "azure", feature = "http"))] +pub(crate) fn deserialize_rfc1123<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: String = serde::Deserialize::deserialize(deserializer)?; + let naive = + chrono::NaiveDateTime::parse_from_str(&s, RFC1123_FMT).map_err(serde::de::Error::custom)?; + Ok(chrono::TimeZone::from_utc_datetime(&chrono::Utc, &naive)) +} + +#[cfg(any(feature = "aws", feature = "azure"))] +pub(crate) fn hmac_sha256(secret: impl AsRef<[u8]>, bytes: impl AsRef<[u8]>) -> ring::hmac::Tag { + let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_ref()); + ring::hmac::sign(&key, bytes.as_ref()) +} + +/// Collect a stream into [`Bytes`] avoiding copying in the event of a single chunk +pub async fn collect_bytes(mut stream: S, size_hint: Option) -> Result +where + E: Send, + S: Stream> + Send + Unpin, +{ + let first = stream.next().await.transpose()?.unwrap_or_default(); + + // Avoid copying if single response + match stream.next().await.transpose()? 
{ + None => Ok(first), + Some(second) => { + let size_hint = size_hint.unwrap_or_else(|| first.len() as u64 + second.len() as u64); + + let mut buf = Vec::with_capacity(size_hint as usize); + buf.extend_from_slice(&first); + buf.extend_from_slice(&second); + while let Some(maybe_bytes) = stream.next().await { + buf.extend_from_slice(&maybe_bytes?); + } + + Ok(buf.into()) + } + } +} + +#[cfg(all(feature = "fs", not(target_arch = "wasm32")))] +/// Takes a function and spawns it to a tokio blocking pool if available +pub(crate) async fn maybe_spawn_blocking(f: F) -> Result +where + F: FnOnce() -> Result + Send + 'static, + T: Send + 'static, +{ + match tokio::runtime::Handle::try_current() { + Ok(runtime) => runtime.spawn_blocking(f).await?, + Err(_) => f(), + } +} + +/// Range requests with a gap less than or equal to this, +/// will be coalesced into a single request by [`coalesce_ranges`] +pub const OBJECT_STORE_COALESCE_DEFAULT: u64 = 1024 * 1024; + +/// Up to this number of range requests will be performed in parallel by [`coalesce_ranges`] +pub(crate) const OBJECT_STORE_COALESCE_PARALLEL: usize = 10; + +/// Takes a function `fetch` that can fetch a range of bytes and uses this to +/// fetch the provided byte `ranges` +/// +/// To improve performance it will: +/// +/// * Combine ranges less than `coalesce` bytes apart into a single call to `fetch` +/// * Make multiple `fetch` requests in parallel (up to maximum of 10) +/// +pub async fn coalesce_ranges( + ranges: &[Range], + fetch: F, + coalesce: u64, +) -> Result, E> +where + F: Send + FnMut(Range) -> Fut, + E: Send, + Fut: std::future::Future> + Send, +{ + let fetch_ranges = merge_ranges(ranges, coalesce); + + let fetched: Vec<_> = futures::stream::iter(fetch_ranges.iter().cloned()) + .map(fetch) + .buffered(OBJECT_STORE_COALESCE_PARALLEL) + .try_collect() + .await?; + + Ok(ranges + .iter() + .map(|range| { + let idx = fetch_ranges.partition_point(|v| v.start <= range.start) - 1; + let fetch_range = &fetch_ranges[idx]; + let fetch_bytes = &fetched[idx]; + + let start = range.start - fetch_range.start; + let end = range.end - fetch_range.start; + let range = (start as usize)..(end as usize).min(fetch_bytes.len()); + fetch_bytes.slice(range) + }) + .collect()) +} + +/// Returns a sorted list of ranges that cover `ranges` +fn merge_ranges(ranges: &[Range], coalesce: u64) -> Vec> { + if ranges.is_empty() { + return vec![]; + } + + let mut ranges = ranges.to_vec(); + ranges.sort_unstable_by_key(|range| range.start); + + let mut ret = Vec::with_capacity(ranges.len()); + let mut start_idx = 0; + let mut end_idx = 1; + + while start_idx != ranges.len() { + let mut range_end = ranges[start_idx].end; + + while end_idx != ranges.len() + && ranges[end_idx] + .start + .checked_sub(range_end) + .map(|delta| delta <= coalesce) + .unwrap_or(true) + { + range_end = range_end.max(ranges[end_idx].end); + end_idx += 1; + } + + let start = ranges[start_idx].start; + let end = range_end; + ret.push(start..end); + + start_idx = end_idx; + end_idx += 1; + } + + ret +} + +/// Request only a portion of an object's bytes +/// +/// These can be created from [usize] ranges, like +/// +/// ```rust +/// # use object_store::GetRange; +/// let range1: GetRange = (50..150).into(); +/// let range2: GetRange = (50..=150).into(); +/// let range3: GetRange = (50..).into(); +/// let range4: GetRange = (..150).into(); +/// ``` +/// +/// Implementations may wish to inspect [`GetResult`] for the exact byte +/// range returned. 
+/// +/// [`GetResult`]: crate::GetResult +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum GetRange { + /// Request a specific range of bytes + /// + /// If the given range is zero-length or starts after the end of the object, + /// an error will be returned. Additionally, if the range ends after the end + /// of the object, the entire remainder of the object will be returned. + /// Otherwise, the exact requested range will be returned. + /// + /// Note that range is u64 (i.e., not usize), + /// as `object_store` supports 32-bit architectures such as WASM + Bounded(Range), + /// Request all bytes starting from a given byte offset + Offset(u64), + /// Request up to the last n bytes + Suffix(u64), +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum InvalidGetRange { + #[error("Wanted range starting at {requested}, but object was only {length} bytes long")] + StartTooLarge { requested: u64, length: u64 }, + + #[error("Range started at {start} and ended at {end}")] + Inconsistent { start: u64, end: u64 }, + + #[error("Range {requested} is larger than system memory limit {max}")] + TooLarge { requested: u64, max: u64 }, +} + +impl GetRange { + pub(crate) fn is_valid(&self) -> Result<(), InvalidGetRange> { + if let Self::Bounded(r) = self { + if r.end <= r.start { + return Err(InvalidGetRange::Inconsistent { + start: r.start, + end: r.end, + }); + } + if (r.end - r.start) > usize::MAX as u64 { + return Err(InvalidGetRange::TooLarge { + requested: r.start, + max: usize::MAX as u64, + }); + } + } + Ok(()) + } + + /// Convert to a [`Range`] if valid. + pub(crate) fn as_range(&self, len: u64) -> Result, InvalidGetRange> { + self.is_valid()?; + match self { + Self::Bounded(r) => { + if r.start >= len { + Err(InvalidGetRange::StartTooLarge { + requested: r.start, + length: len, + }) + } else if r.end > len { + Ok(r.start..len) + } else { + Ok(r.clone()) + } + } + Self::Offset(o) => { + if *o >= len { + Err(InvalidGetRange::StartTooLarge { + requested: *o, + length: len, + }) + } else { + Ok(*o..len) + } + } + Self::Suffix(n) => Ok(len.saturating_sub(*n)..len), + } + } +} + +impl Display for GetRange { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bounded(r) => write!(f, "bytes={}-{}", r.start, r.end - 1), + Self::Offset(o) => write!(f, "bytes={o}-"), + Self::Suffix(n) => write!(f, "bytes=-{n}"), + } + } +} + +impl> From for GetRange { + fn from(value: T) -> Self { + use std::ops::Bound::*; + let first = match value.start_bound() { + Included(i) => *i, + Excluded(i) => i + 1, + Unbounded => 0, + }; + match value.end_bound() { + Included(i) => Self::Bounded(first..(i + 1)), + Excluded(i) => Self::Bounded(first..*i), + Unbounded => Self::Offset(first), + } + } +} +// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html +// +// Do not URI-encode any of the unreserved characters that RFC 3986 defines: +// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). 
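To make the behaviour of such an encode set concrete, the following is a minimal, editor-added sketch (not part of the patch) that builds an equivalent set directly with the `percent_encoding` crate; the `STRICT` constant and the sample input are illustrative stand-ins for the `STRICT_ENCODE_SET` defined immediately below:

```rust
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};

// Hypothetical mirror of the crate's strict encode set: percent-encode
// everything except the RFC 3986 unreserved characters.
const STRICT: AsciiSet = NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'.')
    .remove(b'_')
    .remove(b'~');

fn main() {
    let encoded = utf8_percent_encode("foo bar/baz~1.txt", &STRICT).to_string();
    // Space and '/' are escaped; '-', '.', '_' and '~' survive untouched.
    assert_eq!(encoded, "foo%20bar%2Fbaz~1.txt");
}
```

With such a set, separators and whitespace are escaped while the unreserved characters pass through unchanged, which is the property the SigV4 canonical request rules require.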
+#[cfg(any(feature = "aws", feature = "gcp"))] +pub(crate) const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC + .remove(b'-') + .remove(b'.') + .remove(b'_') + .remove(b'~'); + +/// Computes the SHA256 digest of `body` returned as a hex encoded string +#[cfg(any(feature = "aws", feature = "gcp"))] +pub(crate) fn hex_digest(bytes: &[u8]) -> String { + let digest = ring::digest::digest(&ring::digest::SHA256, bytes); + hex_encode(digest.as_ref()) +} + +/// Returns `bytes` as a lower-case hex encoded string +#[cfg(any(feature = "aws", feature = "gcp"))] +pub(crate) fn hex_encode(bytes: &[u8]) -> String { + use std::fmt::Write; + let mut out = String::with_capacity(bytes.len() * 2); + for byte in bytes { + // String writing is infallible + let _ = write!(out, "{byte:02x}"); + } + out +} + +#[cfg(test)] +mod tests { + use crate::Error; + + use super::*; + use rand::{thread_rng, Rng}; + use std::ops::Range; + + /// Calls coalesce_ranges and validates the returned data is correct + /// + /// Returns the fetched ranges + async fn do_fetch(ranges: Vec>, coalesce: u64) -> Vec> { + let max = ranges.iter().map(|x| x.end).max().unwrap_or(0); + let src: Vec<_> = (0..max).map(|x| x as u8).collect(); + + let mut fetches = vec![]; + let coalesced = coalesce_ranges::<_, Error, _>( + &ranges, + |range| { + fetches.push(range.clone()); + let start = usize::try_from(range.start).unwrap(); + let end = usize::try_from(range.end).unwrap(); + futures::future::ready(Ok(Bytes::from(src[start..end].to_vec()))) + }, + coalesce, + ) + .await + .unwrap(); + + assert_eq!(ranges.len(), coalesced.len()); + for (range, bytes) in ranges.iter().zip(coalesced) { + assert_eq!( + bytes.as_ref(), + &src[usize::try_from(range.start).unwrap()..usize::try_from(range.end).unwrap()] + ); + } + fetches + } + + #[tokio::test] + async fn test_coalesce_ranges() { + let fetches = do_fetch(vec![], 0).await; + assert!(fetches.is_empty()); + + let fetches = do_fetch(vec![0..3; 1], 0).await; + assert_eq!(fetches, vec![0..3]); + + let fetches = do_fetch(vec![0..2, 3..5], 0).await; + assert_eq!(fetches, vec![0..2, 3..5]); + + let fetches = do_fetch(vec![0..1, 1..2], 0).await; + assert_eq!(fetches, vec![0..2]); + + let fetches = do_fetch(vec![0..1, 2..72], 1).await; + assert_eq!(fetches, vec![0..72]); + + let fetches = do_fetch(vec![0..1, 56..72, 73..75], 1).await; + assert_eq!(fetches, vec![0..1, 56..75]); + + let fetches = do_fetch(vec![0..1, 5..6, 7..9, 2..3, 4..6], 1).await; + assert_eq!(fetches, vec![0..9]); + + let fetches = do_fetch(vec![0..1, 5..6, 7..9, 2..3, 4..6], 1).await; + assert_eq!(fetches, vec![0..9]); + + let fetches = do_fetch(vec![0..1, 6..7, 8..9, 10..14, 9..10], 4).await; + assert_eq!(fetches, vec![0..1, 6..14]); + } + + #[tokio::test] + async fn test_coalesce_fuzz() { + let mut rand = thread_rng(); + for _ in 0..100 { + let object_len = rand.gen_range(10..250); + let range_count = rand.gen_range(0..10); + let ranges: Vec<_> = (0..range_count) + .map(|_| { + let start = rand.gen_range(0..object_len); + let max_len = 20.min(object_len - start); + let len = rand.gen_range(0..max_len); + start..start + len + }) + .collect(); + + let coalesce = rand.gen_range(1..5); + let fetches = do_fetch(ranges.clone(), coalesce).await; + + for fetch in fetches.windows(2) { + assert!( + fetch[0].start <= fetch[1].start, + "fetches should be sorted, {:?} vs {:?}", + fetch[0], + fetch[1] + ); + + let delta = fetch[1].end - fetch[0].end; + assert!( + delta > coalesce, + "fetches should not overlap by {}, {:?} 
vs {:?} for {:?}", + coalesce, + fetch[0], + fetch[1], + ranges + ); + } + } + } + + #[test] + fn getrange_str() { + assert_eq!(GetRange::Offset(0).to_string(), "bytes=0-"); + assert_eq!(GetRange::Bounded(10..19).to_string(), "bytes=10-18"); + assert_eq!(GetRange::Suffix(10).to_string(), "bytes=-10"); + } + + #[test] + fn getrange_from() { + assert_eq!(Into::::into(10..15), GetRange::Bounded(10..15),); + assert_eq!(Into::::into(10..=15), GetRange::Bounded(10..16),); + assert_eq!(Into::::into(10..), GetRange::Offset(10),); + assert_eq!(Into::::into(..=15), GetRange::Bounded(0..16)); + } + + #[test] + fn test_as_range() { + let range = GetRange::Bounded(2..5); + assert_eq!(range.as_range(5).unwrap(), 2..5); + + let range = range.as_range(4).unwrap(); + assert_eq!(range, 2..4); + + let range = GetRange::Bounded(3..3); + let err = range.as_range(2).unwrap_err().to_string(); + assert_eq!(err, "Range started at 3 and ended at 3"); + + let range = GetRange::Bounded(2..2); + let err = range.as_range(3).unwrap_err().to_string(); + assert_eq!(err, "Range started at 2 and ended at 2"); + + let range = GetRange::Suffix(3); + assert_eq!(range.as_range(3).unwrap(), 0..3); + assert_eq!(range.as_range(2).unwrap(), 0..2); + + let range = GetRange::Suffix(0); + assert_eq!(range.as_range(0).unwrap(), 0..0); + + let range = GetRange::Offset(2); + let err = range.as_range(2).unwrap_err().to_string(); + assert_eq!( + err, + "Wanted range starting at 2, but object was only 2 bytes long" + ); + + let err = range.as_range(1).unwrap_err().to_string(); + assert_eq!( + err, + "Wanted range starting at 2, but object was only 1 bytes long" + ); + + let range = GetRange::Offset(1); + assert_eq!(range.as_range(2).unwrap(), 1..2); + } +} diff --git a/tests/get_range_file.rs b/tests/get_range_file.rs new file mode 100644 index 0000000..6790c11 --- /dev/null +++ b/tests/get_range_file.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests the default implementation of get_range handles GetResult::File correctly (#4350) + +use async_trait::async_trait; +use bytes::Bytes; +use futures::stream::BoxStream; +use object_store::local::LocalFileSystem; +use object_store::path::Path; +use object_store::*; +use std::fmt::Formatter; +use tempfile::tempdir; + +#[derive(Debug)] +struct MyStore(LocalFileSystem); + +impl std::fmt::Display for MyStore { + fn fmt(&self, _: &mut Formatter<'_>) -> std::fmt::Result { + todo!() + } +} + +#[async_trait] +impl ObjectStore for MyStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.0.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + _location: &Path, + _opts: PutMultipartOpts, + ) -> Result> { + todo!() + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + self.0.get_opts(location, options).await + } + + async fn delete(&self, _: &Path) -> Result<()> { + todo!() + } + + fn list(&self, _: Option<&Path>) -> BoxStream<'static, Result> { + todo!() + } + + async fn list_with_delimiter(&self, _: Option<&Path>) -> Result { + todo!() + } + + async fn copy(&self, _: &Path, _: &Path) -> Result<()> { + todo!() + } + + async fn copy_if_not_exists(&self, _: &Path, _: &Path) -> Result<()> { + todo!() + } +} + +#[tokio::test] +async fn test_get_range() { + let tmp = tempdir().unwrap(); + let store = MyStore(LocalFileSystem::new_with_prefix(tmp.path()).unwrap()); + let path = Path::from("foo"); + + let expected = Bytes::from_static(b"hello world"); + store.put(&path, expected.clone().into()).await.unwrap(); + let fetched = store.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(expected, fetched); + + for range in [0..10, 3..5, 0..expected.len() as u64] { + let data = store.get_range(&path, range.clone()).await.unwrap(); + assert_eq!( + &data[..], + &expected[range.start as usize..range.end as usize] + ) + } + + let over_range = 0..(expected.len() as u64 * 2); + let data = store.get_range(&path, over_range.clone()).await.unwrap(); + assert_eq!(&data[..], expected) +} + +/// Test that, when a requesting a range which overhangs the end of the resource, +/// the resulting [GetResult::range] reports the returned range, +/// not the requested. +#[tokio::test] +async fn test_get_opts_over_range() { + let tmp = tempdir().unwrap(); + let store = MyStore(LocalFileSystem::new_with_prefix(tmp.path()).unwrap()); + let path = Path::from("foo"); + + let expected = Bytes::from_static(b"hello world"); + store.put(&path, expected.clone().into()).await.unwrap(); + + let opts = GetOptions { + range: Some(GetRange::Bounded(0..(expected.len() as u64 * 2))), + ..Default::default() + }; + let res = store.get_opts(&path, opts).await.unwrap(); + assert_eq!(res.range, 0..expected.len() as u64); + assert_eq!(res.bytes().await.unwrap(), expected); +} diff --git a/tests/http.rs b/tests/http.rs new file mode 100644 index 0000000..a9b3145 --- /dev/null +++ b/tests/http.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests the HTTP store implementation + +#[cfg(feature = "http")] +use object_store::{http::HttpBuilder, path::Path, GetOptions, GetRange, ObjectStore}; + +/// Tests that even when reqwest has the `gzip` feature enabled, the HTTP store +/// does not error on a missing `Content-Length` header. +#[tokio::test] +#[cfg(feature = "http")] +async fn test_http_store_gzip() { + let http_store = HttpBuilder::new() + .with_url("https://raw.githubusercontent.com/apache/arrow-rs/refs/heads/main") + .build() + .unwrap(); + + let _ = http_store + .get_opts( + &Path::parse("LICENSE.txt").unwrap(), + GetOptions { + range: Some(GetRange::Bounded(0..100)), + ..Default::default() + }, + ) + .await + .unwrap(); +}
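As a closing illustration of the upload API added in `src/upload.rs` above, here is a small, editor-added sketch (not part of the patch) that drives `WriteMultipart` against the in-memory store; the path, chunk size, write sizes, and concurrency limit are arbitrary, and it assumes a tokio runtime with the `macros` feature enabled:

```rust
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, WriteMultipart};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let path = Path::from("example/data.bin");

    // Start a multipart upload and wrap it in the chunking helper,
    // using a deliberately small chunk size so every write spawns a part.
    let upload = store.put_multipart(&path).await?;
    let mut writer = WriteMultipart::new_with_chunk_size(upload, 1024);

    for _ in 0..4 {
        // Optionally apply back pressure before buffering more data.
        writer.wait_for_capacity(8).await?;
        writer.write(&[0u8; 1024]);
    }

    // Flush any final partial chunk and complete the upload atomically.
    writer.finish().await?;
    Ok(())
}
```

Calling `wait_for_capacity` before each `write` is optional, but it is the hook that gives synchronous producers back pressure instead of letting in-flight parts pile up in the underlying `JoinSet`.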