Skip to content

Commit

Permalink
Refactor PineconeError: use thiserror/anyhow, ensure we support…
Browse files Browse the repository at this point in the history
… `Send + Sync` (#56)

## Problem
Currently, `PineconeError` contains a lot of boilerplate for
implementing `std::error::Error`. @haruska suggested we could maybe
simplify some of the boilerplate using the `thiserror` and `anyhow`
crates.

Additionally, there was an enhancement filed
(#54) to
implement `Send` + `Sync` for `PineconeError` as currently we're unable
to use `PineconeError` in a multithreaded context.

## Solution
Refactor PineconeError to use thiserror and anyhow to reduce some of our
boilerplate for the custom error enum, make sure we can safely use
PineconeError with Send and Sync, add unit test for this

## Type of Change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [X]  New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan
New unit test added to verify that `PineoneError` has properly
implemented the `Send` and `Sync` traits.

`cargo test` -> validate CI passes as expected


---
- To see the specific tasks where the Asana app for GitHub is being
used, see below:
  - https://app.asana.com/0/0/1208161607942725
  - https://app.asana.com/0/0/1208161607942720
  • Loading branch information
austin-denoble authored Sep 4, 2024
1 parent 412a7f0 commit fc241fe
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 136 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ serde = { version = "^1.0", features = ["derive"] }
url = "^2.5"
uuid = { version = "^1.8", features = ["serde", "v4"] }
reqwest = { version = "^0.12", features = ["json", "multipart"] }
thiserror = "1.0.63"
anyhow = "1.0.86"

[dev-dependencies]
temp-env = "0.3"
Expand Down
12 changes: 3 additions & 9 deletions src/pinecone/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,20 +634,14 @@ impl PineconeClient {

// connect to server
let endpoint = Channel::from_shared(host)
.map_err(|e| PineconeError::ConnectionError {
source: Box::new(e),
})?
.map_err(|e| PineconeError::ConnectionError { source: e.into() })?
.tls_config(tls_config)
.map_err(|e| PineconeError::ConnectionError {
source: Box::new(e),
})?;
.map_err(|e| PineconeError::ConnectionError { source: e.into() })?;

let channel = endpoint
.connect()
.await
.map_err(|e| PineconeError::ConnectionError {
source: Box::new(e),
})?;
.map_err(|e| PineconeError::ConnectionError { source: e.into() })?;

// add api key in metadata through interceptor
let token: TonicMetadataVal<_> = self.api_key.parse().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/pinecone/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl PineconeClientConfig {
let client = reqwest::Client::builder()
.default_headers(headers)
.build()
.map_err(|e| PineconeError::ReqwestError { source: e })?;
.map_err(|e| PineconeError::ReqwestError { source: e.into() })?;

let openapi_config = Configuration {
base_path: controller_host.to_string(),
Expand Down
177 changes: 51 additions & 126 deletions src/utils/errors.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::openapi::apis::{Error as OpenApiError, ResponseContent};

use anyhow::Error as AnyhowError;
use reqwest::{self, StatusCode};
use thiserror::Error;

/// PineconeError is the error type for all Pinecone SDK errors.
#[derive(Debug)]
#[derive(Error, Debug)]
pub enum PineconeError {
/// UnknownResponseError: Unknown response error.
#[error("Unknown response error: status: {status}, message: {message}")]
UnknownResponseError {
/// status code
status: StatusCode,
Expand All @@ -14,138 +16,161 @@ pub enum PineconeError {
},

/// ActionForbiddenError: Action is forbidden.
#[error("Action forbidden error: {source}")]
ActionForbiddenError {
/// Source error
source: WrappedResponseContent,
},

/// APIKeyMissingError: API key is not provided as an argument nor in the environment variable `PINECONE_API_KEY`.
#[error("API key missing error: {message}")]
APIKeyMissingError {
/// Error message.
message: String,
},

/// InvalidHeadersError: Provided headers are not valid. Expects JSON.
#[error("Invalid headers error: {message}")]
InvalidHeadersError {
/// Error message.
message: String,
},

/// TimeoutError: Request timed out.
#[error("Timeout error: {message}")]
TimeoutError {
/// Error message.
message: String,
},

/// ConnectionError: Failed to establish a connection.
#[error("Connection error: {source}")]
ConnectionError {
/// inner: Error object for connection error.
source: Box<dyn std::error::Error>,
/// Source of the error.
source: AnyhowError,
},

/// ReqwestError: Error caused by Reqwest
#[error("Reqwest error: {source}")]
ReqwestError {
/// Source error
source: reqwest::Error,
/// Source of the error.
source: AnyhowError,
},

/// SerdeError: Error caused by Serde
#[error("Serde error: {source}")]
SerdeError {
/// Source of the error.
source: serde_json::Error,
source: AnyhowError,
},

/// IoError: Error caused by IO
#[error("IO error: {message}")]
IoError {
/// Error message.
message: String,
},

/// BadRequestError: Bad request. The request body included invalid request parameters
#[error("Bad request error: {source}")]
BadRequestError {
/// Source error
source: WrappedResponseContent,
},

/// UnauthorizedError: Unauthorized. Possibly caused by invalid API key
#[error("Unauthorized error: {source}")]
UnauthorizedError {
/// Source error
source: WrappedResponseContent,
},

/// PodQuotaExceededError: Pod quota exceeded
#[error("Pod quota exceeded error: {source}")]
PodQuotaExceededError {
/// Source error
source: WrappedResponseContent,
},

/// CollectionsQuotaExceededError: Collections quota exceeded
#[error("Collections quota exceeded error: {source}")]
CollectionsQuotaExceededError {
/// Source error
source: WrappedResponseContent,
},

/// InvalidCloudError: Provided cloud is not valid.
#[error("Invalid cloud error: {source}")]
InvalidCloudError {
/// Source error
source: WrappedResponseContent,
},

/// InvalidRegionError: Provided region is not valid.
#[error("Invalid region error: {source}")]
InvalidRegionError {
/// Source error
source: WrappedResponseContent,
},

/// InvalidConfigurationError: Provided configuration is not valid.
#[error("Invalid configuration error: {message}")]
InvalidConfigurationError {
/// Error message.
message: String,
},

/// CollectionNotFoundError: Collection of given name does not exist
#[error("Collection not found error: {source}")]
CollectionNotFoundError {
/// Source error
source: WrappedResponseContent,
},

/// IndexNotFoundError: Index of given name does not exist
#[error("Index not found error: {source}")]
IndexNotFoundError {
/// Source error
source: WrappedResponseContent,
},

/// ResourceAlreadyExistsError: Resource of given name already exists
#[error("Resource already exists error: {source}")]
ResourceAlreadyExistsError {
/// Source error
source: WrappedResponseContent,
},

/// Unprocessable entity error: The request body could not be deserialized
#[error("Unprocessable entity error: {source}")]
UnprocessableEntityError {
/// Source error
source: WrappedResponseContent,
},

/// PendingCollectionError: There is a pending collection created from this index
#[error("Pending collection error: {source}")]
PendingCollectionError {
/// Source error
source: WrappedResponseContent,
},

/// InternalServerError: Internal server error
#[error("Internal server error: {source}")]
InternalServerError {
/// Source error
source: WrappedResponseContent,
},

/// DataPlaneError: Failed to perform a data plane operation.
#[error("Data plane error: {status}")]
DataPlaneError {
/// Error status
status: tonic::Status,
},

/// InferenceError: Failed to perform an inference operation.
#[error("Inference error: {status}")]
InferenceError {
/// Error status
status: tonic::Status,
Expand All @@ -156,8 +181,12 @@ pub enum PineconeError {
impl<T> From<OpenApiError<T>> for PineconeError {
fn from(error: OpenApiError<T>) -> Self {
match error {
OpenApiError::Reqwest(inner) => PineconeError::ReqwestError { source: inner },
OpenApiError::Serde(inner) => PineconeError::SerdeError { source: inner },
OpenApiError::Reqwest(inner) => PineconeError::ReqwestError {
source: inner.into(),
},
OpenApiError::Serde(inner) => PineconeError::SerdeError {
source: inner.into(),
},
OpenApiError::Io(inner) => PineconeError::IoError {
message: inner.to_string(),
},
Expand Down Expand Up @@ -210,123 +239,6 @@ fn parse_forbidden_error(source: WrappedResponseContent, message: String) -> Pin
}
}

impl std::fmt::Display for PineconeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PineconeError::UnknownResponseError { status, message } => {
write!(
f,
"Unknown response error: status: {}, message: {}",
status, message
)
}
PineconeError::ResourceAlreadyExistsError { source } => {
write!(f, "Resource already exists error: {}", source)
}
PineconeError::UnprocessableEntityError { source } => {
write!(f, "Unprocessable entity error: {}", source)
}
PineconeError::PendingCollectionError { source } => {
write!(f, "Pending collection error: {}", source)
}
PineconeError::InternalServerError { source } => {
write!(f, "Internal server error: {}", source)
}
PineconeError::ReqwestError { source } => {
write!(f, "Reqwest error: {}", source.to_string())
}
PineconeError::SerdeError { source } => {
write!(f, "Serde error: {}", source.to_string())
}
PineconeError::IoError { message } => {
write!(f, "IO error: {}", message)
}
PineconeError::BadRequestError { source } => {
write!(f, "Bad request error: {}", source)
}
PineconeError::UnauthorizedError { source } => {
write!(f, "Unauthorized error: status: {}", source)
}
PineconeError::PodQuotaExceededError { source } => {
write!(f, "Pod quota exceeded error: {}", source)
}
PineconeError::CollectionsQuotaExceededError { source } => {
write!(f, "Collections quota exceeded error: {}", source)
}
PineconeError::InvalidCloudError { source } => {
write!(f, "Invalid cloud error: status: {}", source)
}
PineconeError::InvalidRegionError { source } => {
write!(f, "Invalid region error: {}", source)
}
PineconeError::CollectionNotFoundError { source } => {
write!(f, "Collection not found error: {}", source)
}
PineconeError::IndexNotFoundError { source } => {
write!(f, "Index not found error: status: {}", source)
}
PineconeError::APIKeyMissingError { message } => {
write!(f, "API key missing error: {}", message)
}
PineconeError::InvalidHeadersError { message } => {
write!(f, "Invalid headers error: {}", message)
}
PineconeError::TimeoutError { message } => {
write!(f, "Timeout error: {}", message)
}
PineconeError::ConnectionError { source } => {
write!(f, "Connection error: {}", source)
}
PineconeError::DataPlaneError { status } => {
write!(f, "Data plane error: {}", status)
}
PineconeError::InferenceError { status } => {
write!(f, "Inference error: {}", status)
}
PineconeError::ActionForbiddenError { source } => {
write!(f, "Action forbidden error: {}", source)
}
PineconeError::InvalidConfigurationError { message } => {
write!(f, "Invalid configuration error: {}", message)
}
}
}
}

impl std::error::Error for PineconeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
PineconeError::UnknownResponseError {
status: _,
message: _,
} => None,
PineconeError::ReqwestError { source } => Some(source),
PineconeError::SerdeError { source } => Some(source),
PineconeError::IoError { message: _ } => None,
PineconeError::BadRequestError { source } => Some(source),
PineconeError::UnauthorizedError { source } => Some(source),
PineconeError::PodQuotaExceededError { source } => Some(source),
PineconeError::CollectionsQuotaExceededError { source } => Some(source),
PineconeError::InvalidCloudError { source } => Some(source),
PineconeError::InvalidRegionError { source } => Some(source),
PineconeError::CollectionNotFoundError { source } => Some(source),
PineconeError::IndexNotFoundError { source } => Some(source),
PineconeError::ResourceAlreadyExistsError { source } => Some(source),
PineconeError::UnprocessableEntityError { source } => Some(source),
PineconeError::PendingCollectionError { source } => Some(source),
PineconeError::InternalServerError { source } => Some(source),
PineconeError::APIKeyMissingError { message: _ } => None,
PineconeError::InvalidHeadersError { message: _ } => None,
PineconeError::TimeoutError { message: _ } => None,
PineconeError::ConnectionError { source } => Some(source.as_ref()),
PineconeError::DataPlaneError { status } => Some(status),
PineconeError::InferenceError { status } => Some(status),
PineconeError::ActionForbiddenError { source } => Some(source),
PineconeError::InvalidConfigurationError { message: _ } => None,
}
}
}

/// WrappedResponseContent is a wrapper around ResponseContent.
#[derive(Debug)]
pub struct WrappedResponseContent {
Expand Down Expand Up @@ -356,3 +268,16 @@ impl std::fmt::Display for WrappedResponseContent {
write!(f, "status: {} content: {}", self.status, self.content)
}
}

#[cfg(test)]
mod tests {
use super::PineconeError;
use tokio;

fn assert_send_sync<T: Send + Sync>() {}

#[tokio::test]
async fn test_pinecone_error_is_send_sync() {
assert_send_sync::<PineconeError>();
}
}

0 comments on commit fc241fe

Please sign in to comment.