diff --git a/Cargo.toml b/Cargo.toml index 9312284..a3343b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ chrono = { version = "0.4.31", default-features = false, features = [ futures-util = "0.3.28" itertools = "0.11.0" reqwest = { version = "~0.11.20", features = ["blocking", "json"] } +log = "~0.4.20" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/api.rs b/src/api.rs index 6b9d6d9..58fc787 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,6 +3,7 @@ // This file may not be copied, modified, or distributed // except according to those terms. +mod auth; mod client; mod endpoint; mod error; diff --git a/src/api/auth.rs b/src/api/auth.rs new file mode 100644 index 0000000..7393ef8 --- /dev/null +++ b/src/api/auth.rs @@ -0,0 +1,14 @@ +// Licensed under the MIT license +// . +// This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::api::endpoint_prelude::*; +use crate::auth::Auth; + +/// A `auth` modifier that can be applied to any endpoint. +#[derive(Clone)] +pub struct AuthContext { + /// The auth token to use for the endpoint. + token: Auth, +} diff --git a/src/api/eod/eod.rs b/src/api/eod/eod.rs index 3b232a6..b0c1984 100644 --- a/src/api/eod/eod.rs +++ b/src/api/eod/eod.rs @@ -91,13 +91,12 @@ impl<'a> Pageable for Eod<'a> { #[cfg(test)] mod tests { - use std::borrow::BorrowMut; use chrono::NaiveDate; use crate::api::common::SortOrder; use crate::api::eod::Eod; - use crate::api::{self, endpoint_prelude, Query}; + use crate::api::{self, Query}; use crate::test::client::{ExpectedUrl, SingleTestClient}; #[test] diff --git a/src/api/paged/all_at_once.rs b/src/api/paged/all_at_once.rs index acc3e66..e7e734e 100644 --- a/src/api/paged/all_at_once.rs +++ b/src/api/paged/all_at_once.rs @@ -56,7 +56,7 @@ where url }; - let mut page_num = 1; + let mut page_num = 0; let per_page = self.pagination.page_limit(); let per_page_str = per_page.to_string(); @@ -142,3 +142,207 @@ where Ok(std::mem::take(&mut locked_results)) } } + +#[cfg(test)] +mod tests { + use http::StatusCode; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + use crate::api::endpoint_prelude::*; + use crate::api::{self, ApiError, AsyncQuery, Pagination, Query}; + use crate::test::client::{ExpectedUrl, PagedTestClient, SingleTestClient}; + + #[derive(Debug, Default)] + struct Dummy { + with_keyset: bool, + } + + impl Endpoint for Dummy { + fn method(&self) -> Method { + Method::GET + } + + fn endpoint(&self) -> Cow<'static, str> { + "paged_dummy".into() + } + } + + impl Pageable for Dummy { + fn use_keyset_pagination(&self) -> bool { + self.with_keyset + } + } + + #[derive(Debug, Deserialize, Serialize)] + struct DummyResult { + value: u8, + } + + #[test] + fn test_marketstack_non_json_response() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("limit", "1000"), ("offset", "0")]) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, "not json"); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All).query(&client); + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::OK); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_marketstack_error_bad_json() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, ""); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All).query(&client); + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::NOT_FOUND); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_marketstack_error_detection() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_json( + endpoint, + &json!({ + "message": "dummy error message" + }), + ); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All).query(&client); + let err = res.unwrap_err(); + if let ApiError::Marketstack { msg } = err { + assert_eq!(msg, "dummy error message"); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_marketstack_error_detection_unknown() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("limit", "1000"), ("offset", "0")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let err_obj = json!({ + "bogus": "dummy error message" + }); + let client = SingleTestClient::new_json(endpoint, &err_obj); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All).query(&client); + let err = res.unwrap_err(); + if let ApiError::MarketstackUnrecognized { obj } = err { + assert_eq!(obj, err_obj); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_pagination_limit() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy { with_keyset: false }; + + let res: Vec = api::paged(query, Pagination::Limit(25)) + .query(&client) + .unwrap(); + assert_eq!(res.len(), 25); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[tokio::test] + async fn test_pagination_limit_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy { with_keyset: false }; + + let res: Vec = api::paged(query, Pagination::Limit(25)) + .query_async(&client) + .await + .unwrap(); + assert_eq!(res.len(), 25); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[test] + fn test_pagination_all() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy::default(); + + let res: Vec = api::paged(query, Pagination::All).query(&client).unwrap(); + assert_eq!(res.len(), 256); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[tokio::test] + async fn test_pagination_all_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy::default(); + + let res: Vec = api::paged(query, Pagination::All) + .query_async(&client) + .await + .unwrap(); + assert_eq!(res.len(), 256); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } +} diff --git a/src/api/paged/lazy.rs b/src/api/paged/lazy.rs index 480cc1c..04d7e10 100644 --- a/src/api/paged/lazy.rs +++ b/src/api/paged/lazy.rs @@ -122,7 +122,7 @@ where let next_page = if paged.endpoint.use_keyset_pagination() { Page::Keyset(KeysetPage::First) } else { - Page::Number(1) + Page::Number(0) }; let page_state = PageState { @@ -175,12 +175,12 @@ where let mut url = client.rest_endpoint(&self.paged.endpoint.endpoint())?; self.paged.endpoint.parameters().add_to_url(&mut url); - let per_page = self.paged.pagination.page_limit(); - let per_page_str = per_page.to_string(); + let limit = self.paged.pagination.page_limit(); + let limit_str = limit.to_string(); { let mut pairs = url.query_pairs_mut(); - pairs.append_pair("limit", &per_page_str); + pairs.append_pair("limit", &limit_str); next_page.apply_to(&mut pairs); } @@ -353,3 +353,323 @@ where self.current_page.pop().map(Ok) } } + +#[cfg(test)] +mod tests { + use core::panic; + + use futures_util::TryStreamExt; + use http::StatusCode; + use serde::{Deserialize, Serialize}; + use serde_json::json; + + use crate::api::endpoint_prelude::*; + use crate::api::{self, ApiError, Pagination}; + use crate::test::client::{ExpectedUrl, PagedTestClient, SingleTestClient}; + + #[derive(Debug, Default)] + struct Dummy { + with_keyset: bool, + } + + impl Endpoint for Dummy { + fn method(&self) -> Method { + Method::GET + } + + fn endpoint(&self) -> Cow<'static, str> { + "paged_dummy".into() + } + } + + impl Pageable for Dummy { + fn use_keyset_pagination(&self) -> bool { + self.with_keyset + } + } + + #[derive(Debug, Deserialize, Serialize)] + struct DummyResult { + value: u8, + } + + #[test] + fn test_marketstack_non_json_response() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, "not json"); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter(&client) + .collect(); + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::OK); + } else { + panic!("unexpected error: {}", err); + } + } + + #[tokio::test] + async fn test_marketstack_non_json_response_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, "not json"); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter_async(&client) + .try_collect() + .await; + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::OK); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_marketstack_error_bad_json() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, ""); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter(&client) + .collect(); + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::NOT_FOUND); + } else { + panic!("unexpected error: {}", err); + } + } + + #[tokio::test] + async fn test_marketstack_error_bad_json_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_raw(endpoint, ""); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter_async(&client) + .try_collect() + .await; + let err = res.unwrap_err(); + if let ApiError::MarketstackService { status, .. } = err { + assert_eq!(status, http::StatusCode::NOT_FOUND); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_marketstack_error_detection() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_json( + endpoint, + &json!({ + "message": "dummy error message", + }), + ); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter(&client) + .collect(); + let err = res.unwrap_err(); + if let ApiError::Marketstack { msg } = err { + assert_eq!(msg, "dummy error message"); + } else { + panic!("unexpected error: {}", err); + } + } + + #[tokio::test] + async fn test_marketstack_error_detection_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let client = SingleTestClient::new_json( + endpoint, + &json!({ + "message": "dummy error message", + }), + ); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter_async(&client) + .try_collect() + .await; + let err = res.unwrap_err(); + if let ApiError::Marketstack { msg } = err { + assert_eq!(msg, "dummy error message"); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_markestack_error_detection_unknown() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let err_obj = json!({ + "bogus": "dummy error message", + }); + let client = SingleTestClient::new_json(endpoint, &err_obj); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter(&client) + .collect(); + let err = res.unwrap_err(); + if let ApiError::MarketstackUnrecognized { obj } = err { + assert_eq!(obj, err_obj); + } else { + panic!("unexpected error: {}", err); + } + } + + #[tokio::test] + async fn test_marketstack_error_detection_unknown_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .add_query_params(&[("offset", "0"), ("limit", "1000")]) + .status(StatusCode::NOT_FOUND) + .build() + .unwrap(); + let err_obj = json!({ + "bogus": "dummy error message", + }); + let client = SingleTestClient::new_json(endpoint, &err_obj); + let endpoint = Dummy::default(); + + let res: Result, _> = api::paged(endpoint, Pagination::All) + .iter_async(&client) + .try_collect() + .await; + let err = res.unwrap_err(); + if let ApiError::MarketstackUnrecognized { obj } = err { + assert_eq!(obj, err_obj); + } else { + panic!("unexpected error: {}", err); + } + } + + #[test] + fn test_pagination_limit() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy { with_keyset: false }; + + let res: Vec = api::paged(query, Pagination::Limit(25)) + .iter(&client) + .collect::, _>>() + .unwrap(); + assert_eq!(res.len(), 25); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[tokio::test] + async fn test_pagination_limit_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy { with_keyset: false }; + + let res: Vec = api::paged(query, Pagination::Limit(25)) + .iter_async(&client) + .try_collect() + .await + .unwrap(); + assert_eq!(res.len(), 25); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[test] + fn test_pagination_all() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy::default(); + + let res: Vec = api::paged(query, Pagination::All) + .iter(&client) + .collect::, _>>() + .unwrap(); + assert_eq!(res.len(), 256); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } + + #[tokio::test] + async fn test_pagination_all_async() { + let endpoint = ExpectedUrl::builder() + .endpoint("paged_dummy") + .paginated(true) + .build() + .unwrap(); + let client = + PagedTestClient::new_raw(endpoint, (0..=255).map(|value| DummyResult { value })); + let query = Dummy::default(); + + let res: Vec = api::paged(query, Pagination::All) + .iter_async(&client) + .try_collect() + .await + .unwrap(); + assert_eq!(res.len(), 256); + for (i, value) in res.iter().enumerate() { + assert_eq!(value.value, i as u8); + } + } +} diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..b633ec5 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,13 @@ +// Licensed under the MIT license +// . +// This file may not be copied, modified, or distributed +// except according to those terms. + +/// A Marketstack API token. +/// +/// Marketstack only supports one kind of token. +#[derive(Clone)] +pub enum Auth { + /// A personal access token, obtained through Marketstack dashboard. + Token(String), +} diff --git a/src/lib.rs b/src/lib.rs index acfe915..b38d7a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod api; +mod auth; #[cfg(test)] mod test; diff --git a/src/test/client.rs b/src/test/client.rs index 93ced19..c3556a2 100644 --- a/src/test/client.rs +++ b/src/test/client.rs @@ -73,6 +73,7 @@ impl ExpectedUrl { continue; } + println!("{:?}", self.query.iter()); let found = self.query.iter().any(|(expected_key, expected_value)| { key == expected_key && value == expected_value }); @@ -220,8 +221,9 @@ impl Page { fn range(self) -> Range { match self { Page::ByNumber { number, size } => { - assert_ne!(number, 0); - let start = size * (number) - 1; + // Marketstack offset defaults to 0 + assert_eq!(number, 0); + let start = size * (number); start..start + size } } @@ -233,6 +235,7 @@ pub struct PagedTestClient { data: Vec, } +const KEYSET_QUERY_PARAM: &str = "__test_keyset"; const DEFAULT_PAGE_SIZE: usize = 20; impl PagedTestClient { @@ -307,7 +310,7 @@ where } let offset = Page::ByNumber { - number: offset.unwrap_or(1), + number: offset.unwrap_or(0), size: limit, }; let range = { @@ -321,6 +324,29 @@ where assert_eq!(*request.method(), Method::GET); let response = Response::builder().status(self.expected.status); + let response = if pagination { + if range.end + 1 < self.data.len() { + // Generate the URL for the next page. + let next_url = { + let mut next_url = url.clone(); + next_url + .query_pairs_mut() + .clear() + .extend_pairs( + url.query_pairs() + .filter(|(key, _)| key != KEYSET_QUERY_PARAM), + ) + .append_pair(KEYSET_QUERY_PARAM, &format!("{}", range.end)); + next_url + }; + let next_header = format!("<{}>; rel=\"next\"", next_url); + response.header(http::header::LINK, next_header) + } else { + response + } + } else { + response + }; let data_page = &self.data[range];