Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CRT HeadersError enum to include header name #1205

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mountpoint-s3-crt/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
* Checksum hashers no longer implement `std::hash::Hasher`. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082))
* Add bindings to remaining checksum types CRC64, SHA1, and SHA256. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082))
* Add wrapping type `ByteBuf` for `aws_byte_buf`. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082))
* `HeadersError::HeaderNotFound` and `HeadersError::Invalid` variants now include the name of the header.
Despite the new field being private, this may impact any code that was pattern matching on these variants.
([#1205](https://github.com/awslabs/mountpoint-s3/pull/1205))

## v0.10.0 (October 17, 2024)

Expand Down
86 changes: 64 additions & 22 deletions mountpoint-s3-crt/src/http/request_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,38 @@ unsafe impl Send for Headers {}
/// allow threads to simultaneously modify it.
unsafe impl Sync for Headers {}

/// Errors returned by operations on [Headers]
/// Errors returned by operations on [Headers].
///
/// TODO: Where the variant contains an [OsString] for the header name,
/// we could explore using a static [OsStr] to avoid unnecessary memory copies
/// since we know the values at compilation time.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum HeadersError {
/// The header was not found
#[error("Header not found")]
HeaderNotFound,
#[error("Header {0:?} not found")]
HeaderNotFound(OsString),

/// Internal CRT error
#[error("CRT error: {0}")]
CrtError(#[source] Error),

/// Header value could not be converted to String
#[error("Header string was not valid: {0:?}")]
Invalid(OsString),
#[error("Header {name:?} had invalid string value: {value:?}")]
Invalid {
/// Name of the header
name: OsString,
/// Value of the header, which was not valid to convert to [String]
value: OsString,
},
}

// Convert CRT error into HeadersError, mapping the HEADER_NOT_FOUND to HeadersError::HeaderNotFound.
impl From<Error> for HeadersError {
fn from(err: Error) -> Self {
if err == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32).into() {
Self::HeaderNotFound
impl HeadersError {
/// Try to convert the CRT [Error] into [HeadersError::HeaderNotFound], or return [HeadersError::CrtError].
fn try_convert(err: Error, header_name: &OsStr) -> HeadersError {
if err.raw_error() == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32) {
HeadersError::HeaderNotFound(header_name.to_owned())
dannycjones marked this conversation as resolved.
Show resolved Hide resolved
} else {
Self::CrtError(err)
HeadersError::CrtError(err)
}
}
}
Expand All @@ -105,7 +114,11 @@ impl Headers {
/// Create a new [Headers] object in the given allocator.
pub fn new(allocator: &Allocator) -> Result<Self, HeadersError> {
// SAFETY: allocator is a valid aws_allocator, and we check the return is non-null.
let inner = unsafe { aws_http_headers_new(allocator.inner.as_ptr()).ok_or_last_error()? };
let inner = unsafe {
aws_http_headers_new(allocator.inner.as_ptr())
.ok_or_last_error()
.map_err(HeadersError::CrtError)?
};

Ok(Self { inner })
}
Expand All @@ -118,12 +131,14 @@ impl Headers {
}

/// Get the header at the specified index.
pub fn get_index(&self, index: usize) -> Result<Header<OsString, OsString>, HeadersError> {
fn get_index(&self, index: usize) -> Result<Header<OsString, OsString>, HeadersError> {
// SAFETY: `self.inner` is a valid aws_http_headers, and `aws_http_headers_get_index`
// promises to initialize the output `struct aws_http_header *out_header` on success.
let header = unsafe {
let mut header: MaybeUninit<aws_http_header> = MaybeUninit::uninit();
aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr()).ok_or_last_error()?;
aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr())
.ok_or_last_error()
.map_err(HeadersError::CrtError)?;
header.assume_init()
};

Expand Down Expand Up @@ -153,7 +168,9 @@ impl Headers {
// SAFETY: `aws_http_headers_add_header` makes a copy of the underlying strings.
// Also, this function takes a mut reference to `self`, since this function modifies the headers.
unsafe {
aws_http_headers_add_header(self.inner.as_ptr(), &header.inner).ok_or_last_error()?;
aws_http_headers_add_header(self.inner.as_ptr(), &header.inner)
.ok_or_last_error()
.map_err(HeadersError::CrtError)?;
}

Ok(())
Expand All @@ -171,7 +188,9 @@ impl Headers {
// SAFETY: `aws_http_headers_erase` doesn't hold on to a copy of the name we pass in, so it's
// okay to call with with an `aws_byte_cursor` that may not outlive this `Headers`.
unsafe {
aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor()).ok_or_last_error()?;
aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor())
.ok_or_last_error()
.map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;
}

Ok(())
Expand All @@ -183,12 +202,15 @@ impl Headers {
// initialize the output `struct aws_byte_cursor *out_value` on success.
let value = unsafe {
let mut value: MaybeUninit<aws_byte_cursor> = MaybeUninit::uninit();

aws_http_headers_get(
self.inner.as_ptr(),
name.as_ref().as_aws_byte_cursor(),
value.as_mut_ptr(),
)
.ok_or_last_error()?;
.ok_or_last_error()
.map_err(|err| HeadersError::try_convert(err, name.as_ref()))?;

value.assume_init()
};

Expand All @@ -203,12 +225,17 @@ impl Headers {

/// Get a single header by name as a [String].
pub fn get_as_string<H: AsRef<OsStr>>(&self, name: H) -> Result<String, HeadersError> {
let name = name.as_ref();
let header = self.get(name)?;
let value = header.value();
if let Some(s) = value.to_str() {
Ok(s.to_string())
} else {
Err(HeadersError::Invalid(value.clone()))
let err = HeadersError::Invalid {
name: name.to_owned(),
value: value.clone(),
};
Err(err)
}
}

Expand Down Expand Up @@ -263,7 +290,7 @@ impl Iterator for HeadersIterator<'_> {
let header = self
.headers
.get_index(self.offset)
.expect("HeadersIterator: failed to get next header");
.expect("headers at any offset smaller than original count should always exist given mut access");
passaro marked this conversation as resolved.
Show resolved Hide resolved
self.offset += 1;

Some((header.name, header.value))
Expand Down Expand Up @@ -417,14 +444,29 @@ mod test {
#[test]
fn test_header_not_present() {
let headers = Headers::new(&Allocator::default()).expect("failed to create headers");
assert!(!headers.has_header("a"));

assert!(!headers.has_header("a"), "header should not be present");

let error = headers.get("a").expect_err("should fail because header is not present");
assert_eq!(error, HeadersError::HeaderNotFound, "should fail with HeaderNotFound");
assert_eq!(
error.to_string(),
"Header \"a\" not found",
"header error display should match expected output",
);
if let HeadersError::HeaderNotFound(name) = error {
assert_eq!(name, "a", "header name should match original argument");
} else {
panic!("should fail with HeaderNotFound");
}

let error = headers
.get_as_string("a")
.expect_err("should fail because header is not present");
assert_eq!(error, HeadersError::HeaderNotFound, "should fail with HeaderNotFound");
if let HeadersError::HeaderNotFound(name) = error {
assert_eq!(name, "a", "header name should match original argument");
} else {
panic!("should fail with HeaderNotFound");
}

let header = headers
.get_as_optional_string("a")
Expand Down
Loading