diff --git a/src/httputil.rs b/src/httputil.rs index 37d206b..c8e3818 100644 --- a/src/httputil.rs +++ b/src/httputil.rs @@ -1,5 +1,5 @@ use crate::consts::USER_AGENT; -use reqwest::{Request, Response, StatusCode}; +use reqwest::{Method, Request, Response, StatusCode}; use reqwest_middleware::{Middleware, Next}; use serde::de::DeserializeOwned; use thiserror::Error; @@ -21,10 +21,10 @@ impl Client { Ok(Client(client)) } - async fn get(&self, url: Url) -> Result { + pub(crate) async fn request(&self, method: Method, url: Url) -> Result { let r = self .0 - .get(url.clone()) + .request(method, url.clone()) .send() .await .map_err(|source| HttpError::Send { @@ -38,6 +38,14 @@ impl Client { .map_err(|source| HttpError::Status { url, source }) } + pub(crate) async fn head(&self, url: Url) -> Result { + self.request(Method::HEAD, url).await + } + + pub(crate) async fn get(&self, url: Url) -> Result { + self.request(Method::GET, url).await + } + pub(crate) async fn get_json(&self, url: Url) -> Result { self.get(url.clone()) .await? diff --git a/src/s3.rs b/src/s3.rs index d27a0ec..50f6bbb 100644 --- a/src/s3.rs +++ b/src/s3.rs @@ -1,11 +1,10 @@ -use super::consts::USER_AGENT; -use super::paths::{ParsePureDirPathError, ParsePurePathError, PureDirPath, PurePath}; +use crate::httputil::{self, BuildClientError, HttpError}; +use crate::paths::{ParsePureDirPathError, ParsePurePathError, PureDirPath, PurePath}; use async_stream::try_stream; use aws_sdk_s3::{operation::list_objects_v2::ListObjectsV2Error, types::CommonPrefix, Client}; use aws_smithy_runtime_api::client::{orchestrator::HttpResponse, result::SdkError}; use aws_smithy_types_convert::date_time::DateTimeExt; use futures_util::{Stream, TryStreamExt}; -use reqwest::ClientBuilder; use smartstring::alias::CompactString; use std::cmp::Ordering; use std::sync::Arc; @@ -492,18 +491,15 @@ pub(crate) enum TryFromAwsObjectError { // The AWS SDK currently cannot be used for this: // pub(crate) async fn get_bucket_region(bucket: &str) -> Result { - let client = ClientBuilder::new() - .user_agent(USER_AGENT) - .https_only(true) - .build() - .map_err(GetBucketRegionError::BuildClient)?; - let r = client - .head(format!("https://{bucket}.s3.amazonaws.com")) - .send() - .await - .map_err(GetBucketRegionError::Send)? - .error_for_status() - .map_err(GetBucketRegionError::Status)?; + let url_str = format!("https://{bucket}.s3.amazonaws.com"); + let url = url_str + .parse::() + .map_err(|source| GetBucketRegionError::BadUrl { + url: url_str, + source, + })?; + let client = httputil::Client::new()?; + let r = client.head(url).await?; match r.headers().get("x-amz-bucket-region").map(|hv| hv.to_str()) { Some(Ok(region)) => Ok(region.to_owned()), Some(Err(e)) => Err(GetBucketRegionError::BadHeader(e)), @@ -513,12 +509,15 @@ pub(crate) async fn get_bucket_region(bucket: &str) -> Result