From b74e6452bd9e5c848794a55f8fe864206545ea1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 2 May 2024 11:52:55 +0200 Subject: [PATCH] refactor: Read chunks until n bytes instead of take --- hook-worker/src/error.rs | 11 +++++++++++ hook-worker/src/util.rs | 22 ++++++++++++++++++---- hook-worker/src/worker.rs | 30 ++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/hook-worker/src/error.rs b/hook-worker/src/error.rs index 2458332..1d3da8d 100644 --- a/hook-worker/src/error.rs +++ b/hook-worker/src/error.rs @@ -38,6 +38,17 @@ pub enum WebhookRequestError { }, } +/// Enumeration of errors that can occur while handling a `reqwest::Response`. +/// Currently, not consumed anywhere. Grouped here to support a common error type for +/// `utils::first_n_bytes_of_response`. +#[derive(Error, Debug)] +pub enum WebhookResponseError { + #[error("failed to parse a response as UTF8")] + ParseUTF8StringError(#[from] std::str::Utf8Error), + #[error("error while iterating over response body chunks")] + StreamIterationError(#[from] reqwest::Error), +} + /// Implement display of `WebhookRequestError` by appending to the underlying `reqwest::Error` /// any response message if available. impl fmt::Display for WebhookRequestError { diff --git a/hook-worker/src/util.rs b/hook-worker/src/util.rs index 0a9f372..b1ce2af 100644 --- a/hook-worker/src/util.rs +++ b/hook-worker/src/util.rs @@ -1,16 +1,30 @@ +use crate::error::WebhookResponseError; use futures::StreamExt; use reqwest::Response; pub async fn first_n_bytes_of_response( response: Response, n: usize, -) -> Result { - let mut body = response.bytes_stream().take(n); - let mut buffer = String::new(); +) -> Result { + let mut body = response.bytes_stream(); + let mut buffer = String::with_capacity(n); while let Some(chunk) = body.next().await { + if buffer.len() >= n { + // Early return before reading next chunk. + break; + } + let chunk = chunk?; - buffer.push_str(std::str::from_utf8(&chunk).unwrap()); + let chunk_str = std::str::from_utf8(&chunk)?; + if let Some(partial_chunk_str) = + chunk_str.get(0..std::cmp::min(n - buffer.len(), chunk_str.len())) + { + buffer.push_str(&partial_chunk_str); + } else { + // For whatever reason, we are out of bounds, give up. + break; + } } Ok(buffer) diff --git a/hook-worker/src/worker.rs b/hook-worker/src/worker.rs index bfd7179..14582b2 100644 --- a/hook-worker/src/worker.rs +++ b/hook-worker/src/worker.rs @@ -408,6 +408,7 @@ async fn send_webhook( Err(WebhookError::Request( WebhookRequestError::RetryableRequestError { error: err, + // TODO: Make amount of bytes configurable. response: first_n_bytes_of_response(response, 10 * 1024).await.ok(), retry_after, }, @@ -636,4 +637,33 @@ mod tests { )); } } + + #[sqlx::test(migrations = "../migrations")] + async fn test_error_message_contains_up_to_n_bytes_of_response_body(_pg: PgPool) { + let method = HttpMethod::POST; + let url = "http://localhost:18081/fail"; + let headers = collections::HashMap::new(); + // This is double the current hardcoded amount of bytes. + // TODO: Make this configurable and chage it here too. + let body = (0..20 * 1024).map(|_| "a").collect::>().concat(); + let client = reqwest::Client::new(); + + let err = send_webhook(client, &method, url, &headers, body.to_owned()) + .await + .err() + .expect("request didn't fail when it should have failed"); + + assert!(matches!(err, WebhookError::Request(..))); + if let WebhookError::Request(request_error) = err { + assert_eq!(request_error.status(), Some(StatusCode::BAD_REQUEST)); + assert!(request_error.to_string().contains(&body[0..10 * 1024])); + // The 81 bytes account for the reqwest erorr message as described below. + assert_eq!(request_error.to_string().len(), 10 * 1024 + 81); + // This is the display implementation of reqwest. Just checking it is still there. + // See: https://github.com/seanmonstar/reqwest/blob/master/src/error.rs + assert!(request_error.to_string().contains( + "HTTP status client error (400 Bad Request) for url (http://localhost:18081/fail)" + )); + } + } }