diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 1cff8e0..0000000 --- a/rustfmt.toml +++ /dev/null @@ -1,6 +0,0 @@ -# rustfmt configurations from master branch, check rustfmt version -# (cargo fmt --version). -# https://github.com/rust-lang/rustfmt/blob/master/Configurations.md - -edition = "2021" -tab_spaces = 2 \ No newline at end of file diff --git a/src/app.rs b/src/app.rs index 2331f81..665ffca 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,8 +1,8 @@ use axum::http::header; use axum::Router; use tower_http::{ - compression::CompressionLayer, cors::CorsLayer, propagate_header::PropagateHeaderLayer, - sensitive_headers::SetSensitiveHeadersLayer, trace, + compression::CompressionLayer, cors::CorsLayer, propagate_header::PropagateHeaderLayer, + sensitive_headers::SetSensitiveHeadersLayer, trace, }; use crate::logger; @@ -10,39 +10,39 @@ use crate::models; use crate::routes; pub async fn create_app() -> Router { - logger::setup(); + logger::setup(); - models::sync_indexes() - .await - .expect("Failed to sync database indexes"); + models::sync_indexes() + .await + .expect("Failed to sync database indexes"); - Router::new() - .merge(routes::status::create_route()) - .merge(routes::user::create_route()) - .merge(Router::new().nest( - "/v1", - // All public v1 routes will be nested here. - Router::new().merge(routes::cat::create_route()), - )) - // High level logging of requests and responses - .layer( - trace::TraceLayer::new_for_http() - .make_span_with(trace::DefaultMakeSpan::new().include_headers(true)) - .on_request(trace::DefaultOnRequest::new().level(tracing::Level::INFO)) - .on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)), - ) - // Mark the `Authorization` request header as sensitive so it doesn't - // show in logs. - .layer(SetSensitiveHeadersLayer::new(std::iter::once( - header::AUTHORIZATION, - ))) - // Compress responses - .layer(CompressionLayer::new()) - // Propagate `X-Request-Id`s from requests to responses - .layer(PropagateHeaderLayer::new(header::HeaderName::from_static( - "x-request-id", - ))) - // CORS configuration. This should probably be more restrictive in - // production. - .layer(CorsLayer::permissive()) + Router::new() + .merge(routes::status::create_route()) + .merge(routes::user::create_route()) + .merge(Router::new().nest( + "/v1", + // All public v1 routes will be nested here. + Router::new().merge(routes::cat::create_route()), + )) + // High level logging of requests and responses + .layer( + trace::TraceLayer::new_for_http() + .make_span_with(trace::DefaultMakeSpan::new().include_headers(true)) + .on_request(trace::DefaultOnRequest::new().level(tracing::Level::INFO)) + .on_response(trace::DefaultOnResponse::new().level(tracing::Level::INFO)), + ) + // Mark the `Authorization` request header as sensitive so it doesn't + // show in logs. + .layer(SetSensitiveHeadersLayer::new(std::iter::once( + header::AUTHORIZATION, + ))) + // Compress responses + .layer(CompressionLayer::new()) + // Propagate `X-Request-Id`s from requests to responses + .layer(PropagateHeaderLayer::new(header::HeaderName::from_static( + "x-request-id", + ))) + // CORS configuration. This should probably be more restrictive in + // production. + .layer(CorsLayer::permissive()) } diff --git a/src/database.rs b/src/database.rs index d6c1d0b..48c5ede 100644 --- a/src/database.rs +++ b/src/database.rs @@ -7,15 +7,15 @@ use crate::settings::SETTINGS; static CONNECTION: OnceCell = OnceCell::const_new(); pub async fn connection() -> &'static Database { - CONNECTION - .get_or_init(|| async { - let db_uri = SETTINGS.database.uri.as_str(); - let db_name = SETTINGS.database.name.as_str(); + CONNECTION + .get_or_init(|| async { + let db_uri = SETTINGS.database.uri.as_str(); + let db_name = SETTINGS.database.name.as_str(); - mongodb::Client::with_uri_str(db_uri) + mongodb::Client::with_uri_str(db_uri) + .await + .expect("Failed to initialize MongoDB connection") + .database(db_name) + }) .await - .expect("Failed to initialize MongoDB connection") - .database(db_name) - }) - .await } diff --git a/src/errors.rs b/src/errors.rs index 63c8e01..3859941 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -11,87 +11,91 @@ use wither::WitherError; #[derive(thiserror::Error, Debug)] #[error("...")] pub enum Error { - #[error("{0}")] - Wither(#[from] WitherError), + #[error("{0}")] + Wither(#[from] WitherError), - #[error("{0}")] - Mongo(#[from] MongoError), + #[error("{0}")] + Mongo(#[from] MongoError), - #[error("Error parsing ObjectID {0}")] - ParseObjectID(String), + #[error("Error parsing ObjectID {0}")] + ParseObjectID(String), - #[error("{0}")] - SerializeMongoResponse(#[from] bson::de::Error), + #[error("{0}")] + SerializeMongoResponse(#[from] bson::de::Error), - #[error("{0}")] - Authenticate(#[from] AuthenticateError), + #[error("{0}")] + Authenticate(#[from] AuthenticateError), - #[error("{0}")] - BadRequest(#[from] BadRequest), + #[error("{0}")] + BadRequest(#[from] BadRequest), - #[error("{0}")] - NotFound(#[from] NotFound), + #[error("{0}")] + NotFound(#[from] NotFound), - #[error("{0}")] - RunSyncTask(#[from] JoinError), + #[error("{0}")] + RunSyncTask(#[from] JoinError), - #[error("{0}")] - HashPassword(#[from] BcryptError), + #[error("{0}")] + HashPassword(#[from] BcryptError), } impl Error { - fn get_codes(&self) -> (StatusCode, u16) { - match *self { - // 4XX Errors - Error::ParseObjectID(_) => (StatusCode::BAD_REQUEST, 40001), - Error::BadRequest(_) => (StatusCode::BAD_REQUEST, 40002), - Error::NotFound(_) => (StatusCode::NOT_FOUND, 40003), - Error::Authenticate(AuthenticateError::WrongCredentials) => (StatusCode::UNAUTHORIZED, 40004), - Error::Authenticate(AuthenticateError::InvalidToken) => (StatusCode::UNAUTHORIZED, 40005), - Error::Authenticate(AuthenticateError::Locked) => (StatusCode::LOCKED, 40006), - - // 5XX Errors - Error::Authenticate(AuthenticateError::TokenCreation) => { - (StatusCode::INTERNAL_SERVER_ERROR, 5001) - } - Error::Wither(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5002), - Error::Mongo(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5003), - Error::SerializeMongoResponse(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5004), - Error::RunSyncTask(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5005), - Error::HashPassword(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5006), + fn get_codes(&self) -> (StatusCode, u16) { + match *self { + // 4XX Errors + Error::ParseObjectID(_) => (StatusCode::BAD_REQUEST, 40001), + Error::BadRequest(_) => (StatusCode::BAD_REQUEST, 40002), + Error::NotFound(_) => (StatusCode::NOT_FOUND, 40003), + Error::Authenticate(AuthenticateError::WrongCredentials) => { + (StatusCode::UNAUTHORIZED, 40004) + } + Error::Authenticate(AuthenticateError::InvalidToken) => { + (StatusCode::UNAUTHORIZED, 40005) + } + Error::Authenticate(AuthenticateError::Locked) => (StatusCode::LOCKED, 40006), + + // 5XX Errors + Error::Authenticate(AuthenticateError::TokenCreation) => { + (StatusCode::INTERNAL_SERVER_ERROR, 5001) + } + Error::Wither(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5002), + Error::Mongo(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5003), + Error::SerializeMongoResponse(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5004), + Error::RunSyncTask(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5005), + Error::HashPassword(_) => (StatusCode::INTERNAL_SERVER_ERROR, 5006), + } } - } - pub fn bad_request() -> Self { - Error::BadRequest(BadRequest {}) - } + pub fn bad_request() -> Self { + Error::BadRequest(BadRequest {}) + } - pub fn not_found() -> Self { - Error::NotFound(NotFound {}) - } + pub fn not_found() -> Self { + Error::NotFound(NotFound {}) + } } impl IntoResponse for Error { - fn into_response(self) -> Response { - let (status_code, code) = self.get_codes(); - let message = self.to_string(); - let body = Json(json!({ "code": code, "message": message })); + fn into_response(self) -> Response { + let (status_code, code) = self.get_codes(); + let message = self.to_string(); + let body = Json(json!({ "code": code, "message": message })); - (status_code, body).into_response() - } + (status_code, body).into_response() + } } #[derive(thiserror::Error, Debug)] #[error("...")] pub enum AuthenticateError { - #[error("Wrong authentication credentials")] - WrongCredentials, - #[error("Failed to create authentication token")] - TokenCreation, - #[error("Invalid authentication credentials")] - InvalidToken, - #[error("User is locked")] - Locked, + #[error("Wrong authentication credentials")] + WrongCredentials, + #[error("Failed to create authentication token")] + TokenCreation, + #[error("Invalid authentication credentials")] + InvalidToken, + #[error("User is locked")] + Locked, } #[derive(thiserror::Error, Debug)] diff --git a/src/logger.rs b/src/logger.rs index 006c19b..ec360d4 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -3,13 +3,13 @@ use std::env; use crate::settings::SETTINGS; pub fn setup() { - if env::var_os("RUST_LOG").is_none() { - let app_name = env::var("CARGO_PKG_NAME").unwrap(); - let level = SETTINGS.logger.level.as_str(); - let env = format!("{app_name }={level},tower_http={level}"); + if env::var_os("RUST_LOG").is_none() { + let app_name = env::var("CARGO_PKG_NAME").unwrap(); + let level = SETTINGS.logger.level.as_str(); + let env = format!("{app_name }={level},tower_http={level}"); - env::set_var("RUST_LOG", env); - } + env::set_var("RUST_LOG", env); + } - tracing_subscriber::fmt::init(); + tracing_subscriber::fmt::init(); } diff --git a/src/main.rs b/src/main.rs index 84b9ae6..42ea311 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,14 +23,14 @@ use settings::SETTINGS; #[tokio::main] async fn main() { - let app = app::create_app().await; + let app = app::create_app().await; - let port = SETTINGS.server.port; - let address = SocketAddr::from(([127, 0, 0, 1], port)); + let port = SETTINGS.server.port; + let address = SocketAddr::from(([127, 0, 0, 1], port)); - info!("Server listening on {}", &address); - axum::Server::bind(&address) - .serve(app.into_make_service()) - .await - .expect("Failed to start server"); + info!("Server listening on {}", &address); + axum::Server::bind(&address) + .serve(app.into_make_service()) + .await + .expect("Failed to start server"); } diff --git a/src/models/cat.rs b/src/models/cat.rs index a654a0c..660f4cd 100644 --- a/src/models/cat.rs +++ b/src/models/cat.rs @@ -14,48 +14,48 @@ impl ModelExt for Cat {} #[derive(Debug, Clone, Serialize, Deserialize, WitherModel, Validate)] #[model(index(keys = r#"doc!{ "user": 1, "created_at": 1 }"#))] pub struct Cat { - #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] - pub id: Option, - pub user: ObjectId, - pub name: String, - pub updated_at: Date, - pub created_at: Date, + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub user: ObjectId, + pub name: String, + pub updated_at: Date, + pub created_at: Date, } impl Cat { - pub fn new(user: ObjectId, name: String) -> Self { - let now = date::now(); - Self { - id: None, - user, - name, - updated_at: now, - created_at: now, + pub fn new(user: ObjectId, name: String) -> Self { + let now = date::now(); + Self { + id: None, + user, + name, + updated_at: now, + created_at: now, + } } - } } #[derive(Debug, Serialize, Deserialize)] pub struct PublicCat { - #[serde(alias = "_id", serialize_with = "serialize_object_id_as_hex_string")] - pub id: ObjectId, - #[serde(serialize_with = "serialize_object_id_as_hex_string")] - pub user: ObjectId, - pub name: String, - #[serde(with = "bson_datetime_as_rfc3339_string")] - pub updated_at: Date, - #[serde(with = "bson_datetime_as_rfc3339_string")] - pub created_at: Date, + #[serde(alias = "_id", serialize_with = "serialize_object_id_as_hex_string")] + pub id: ObjectId, + #[serde(serialize_with = "serialize_object_id_as_hex_string")] + pub user: ObjectId, + pub name: String, + #[serde(with = "bson_datetime_as_rfc3339_string")] + pub updated_at: Date, + #[serde(with = "bson_datetime_as_rfc3339_string")] + pub created_at: Date, } impl From for PublicCat { - fn from(cat: Cat) -> Self { - Self { - id: cat.id.unwrap(), - user: cat.user, - name: cat.name.clone(), - updated_at: cat.updated_at, - created_at: cat.created_at, + fn from(cat: Cat) -> Self { + Self { + id: cat.id.unwrap(), + user: cat.user, + name: cat.name.clone(), + updated_at: cat.updated_at, + created_at: cat.created_at, + } } - } } diff --git a/src/models/mod.rs b/src/models/mod.rs index 460bf22..0ace409 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -5,8 +5,8 @@ use crate::utils::models::ModelExt; use crate::Error; pub async fn sync_indexes() -> Result<(), Error> { - user::User::sync_indexes().await?; - cat::Cat::sync_indexes().await?; + user::User::sync_indexes().await?; + cat::Cat::sync_indexes().await?; - Ok(()) + Ok(()) } diff --git a/src/models/user.rs b/src/models/user.rs index be9654f..dba8baa 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -16,78 +16,78 @@ impl ModelExt for User {} #[derive(Debug, Clone, Serialize, Deserialize, WitherModel, Validate)] #[model(index(keys = r#"doc!{ "email": 1 }"#, options = r#"doc!{ "unique": true }"#))] pub struct User { - #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] - pub id: Option, - #[validate(length(min = 1))] - pub name: String, - #[validate(email)] - pub email: String, - pub password: String, - pub updated_at: Date, - pub created_at: Date, - pub locked_at: Option, + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + #[validate(length(min = 1))] + pub name: String, + #[validate(email)] + pub email: String, + pub password: String, + pub updated_at: Date, + pub created_at: Date, + pub locked_at: Option, } impl User { - pub fn new(name: A, email: B, password_hash: C) -> Self - where - A: Into, - B: Into, - C: Into, - { - let now = date::now(); - Self { - id: None, - name: name.into(), - email: email.into(), - password: password_hash.into(), - updated_at: now, - created_at: now, - locked_at: None, + pub fn new(name: A, email: B, password_hash: C) -> Self + where + A: Into, + B: Into, + C: Into, + { + let now = date::now(); + Self { + id: None, + name: name.into(), + email: email.into(), + password: password_hash.into(), + updated_at: now, + created_at: now, + locked_at: None, + } } - } - pub fn is_password_match(&self, password: &str) -> bool { - bcrypt::verify(password, self.password.as_ref()).unwrap_or(false) - } + pub fn is_password_match(&self, password: &str) -> bool { + bcrypt::verify(password, self.password.as_ref()).unwrap_or(false) + } } #[derive(Debug, Serialize, Deserialize)] pub struct PublicUser { - #[serde(alias = "_id", serialize_with = "serialize_object_id_as_hex_string")] - pub id: ObjectId, - pub name: String, - pub email: String, - #[serde(with = "bson_datetime_as_rfc3339_string")] - pub updated_at: Date, - #[serde(with = "bson_datetime_as_rfc3339_string")] - pub created_at: Date, + #[serde(alias = "_id", serialize_with = "serialize_object_id_as_hex_string")] + pub id: ObjectId, + pub name: String, + pub email: String, + #[serde(with = "bson_datetime_as_rfc3339_string")] + pub updated_at: Date, + #[serde(with = "bson_datetime_as_rfc3339_string")] + pub created_at: Date, } impl From for PublicUser { - fn from(user: User) -> Self { - Self { - id: user.id.unwrap(), - name: user.name.clone(), - email: user.email.clone(), - updated_at: user.updated_at, - created_at: user.created_at, + fn from(user: User) -> Self { + Self { + id: user.id.unwrap(), + name: user.name.clone(), + email: user.email.clone(), + updated_at: user.updated_at, + created_at: user.created_at, + } } - } } pub async fn hash_password

(password: P) -> Result where - P: AsRef + Send + 'static, + P: AsRef + Send + 'static, { - // TODO: Hash password with salt. - // https://docs.rs/bcrypt/latest/bcrypt/fn.hash_with_salt.html - #[cfg(not(test))] - let cost = bcrypt::DEFAULT_COST; - #[cfg(test)] - let cost = 4; - task::spawn_blocking(move || bcrypt::hash(password.as_ref(), cost)) - .await - .map_err(Error::RunSyncTask)? - .map_err(Error::HashPassword) + // TODO: Hash password with salt. + // https://docs.rs/bcrypt/latest/bcrypt/fn.hash_with_salt.html + #[cfg(not(test))] + let cost = bcrypt::DEFAULT_COST; + #[cfg(test)] + let cost = 4; + task::spawn_blocking(move || bcrypt::hash(password.as_ref(), cost)) + .await + .map_err(Error::RunSyncTask)? + .map_err(Error::HashPassword) } diff --git a/src/routes/cat.rs b/src/routes/cat.rs index 3f65fe9..c9b6cf1 100644 --- a/src/routes/cat.rs +++ b/src/routes/cat.rs @@ -1,8 +1,8 @@ use axum::http::StatusCode; use axum::{ - extract::Path, - routing::{delete, get, post, put}, - Json, Router, + extract::Path, + routing::{delete, get, post, put}, + Json, Router, }; use bson::doc; use serde::{Deserialize, Serialize}; @@ -19,120 +19,120 @@ use crate::utils::to_object_id::to_object_id; use crate::utils::token::TokenUser; pub fn create_route() -> Router { - Router::new() - .route("/cats", post(create_cat)) - .route("/cats", get(query_cats)) - .route("/cats/:id", get(get_cat_by_id)) - .route("/cats/:id", delete(remove_cat_by_id)) - .route("/cats/:id", put(update_cat_by_id)) + Router::new() + .route("/cats", post(create_cat)) + .route("/cats", get(query_cats)) + .route("/cats/:id", get(get_cat_by_id)) + .route("/cats/:id", delete(remove_cat_by_id)) + .route("/cats/:id", put(update_cat_by_id)) } async fn create_cat(user: TokenUser, Json(payload): Json) -> Response { - let cat = Cat::new(user.id, payload.name); - let cat = Cat::create(cat).await?; - let res = PublicCat::from(cat); + let cat = Cat::new(user.id, payload.name); + let cat = Cat::create(cat).await?; + let res = PublicCat::from(cat); - let res = CustomResponseBuilder::new() - .body(res) - .status_code(StatusCode::CREATED) - .build(); + let res = CustomResponseBuilder::new() + .body(res) + .status_code(StatusCode::CREATED) + .build(); - Ok(res) + Ok(res) } async fn query_cats(user: TokenUser, pagination: Pagination) -> Response> { - let options = FindOptions::builder() - .sort(doc! { "created_at": -1_i32 }) - .skip(pagination.offset) - .limit(pagination.limit as i64) - .build(); - - let (cats, count) = Cat::find_and_count(doc! { "user": &user.id }, options).await?; - let cats = cats.into_iter().map(Into::into).collect::>(); - - let res = CustomResponseBuilder::new() - .body(cats) - .pagination(ResponsePagination { - count, - offset: pagination.offset, - limit: pagination.limit, - }) - .build(); - - debug!("Returning cats"); - Ok(res) + let options = FindOptions::builder() + .sort(doc! { "created_at": -1_i32 }) + .skip(pagination.offset) + .limit(pagination.limit as i64) + .build(); + + let (cats, count) = Cat::find_and_count(doc! { "user": &user.id }, options).await?; + let cats = cats.into_iter().map(Into::into).collect::>(); + + let res = CustomResponseBuilder::new() + .body(cats) + .pagination(ResponsePagination { + count, + offset: pagination.offset, + limit: pagination.limit, + }) + .build(); + + debug!("Returning cats"); + Ok(res) } async fn get_cat_by_id(user: TokenUser, Path(id): Path) -> Result, Error> { - let cat_id = to_object_id(id)?; - let cat = Cat::find_one(doc! { "_id": cat_id, "user": &user.id }, None) - .await? - .map(PublicCat::from); - - let cat = match cat { - Some(cat) => cat, - None => { - debug!("Cat not found, returning 404 status code"); - return Err(Error::not_found()); - } - }; - - debug!("Returning cat"); - Ok(Json(cat)) + let cat_id = to_object_id(id)?; + let cat = Cat::find_one(doc! { "_id": cat_id, "user": &user.id }, None) + .await? + .map(PublicCat::from); + + let cat = match cat { + Some(cat) => cat, + None => { + debug!("Cat not found, returning 404 status code"); + return Err(Error::not_found()); + } + }; + + debug!("Returning cat"); + Ok(Json(cat)) } async fn remove_cat_by_id( - user: TokenUser, - Path(id): Path, + user: TokenUser, + Path(id): Path, ) -> Result, Error> { - let cat_id = to_object_id(id)?; - let delete_result = Cat::delete_one(doc! { "_id": cat_id, "user": &user.id }).await?; + let cat_id = to_object_id(id)?; + let delete_result = Cat::delete_one(doc! { "_id": cat_id, "user": &user.id }).await?; - if delete_result.deleted_count == 0 { - debug!("Cat not found, returning 404 status code"); - return Err(Error::not_found()); - } + if delete_result.deleted_count == 0 { + debug!("Cat not found, returning 404 status code"); + return Err(Error::not_found()); + } - let res = CustomResponseBuilder::new() - .status_code(StatusCode::NO_CONTENT) - .build(); + let res = CustomResponseBuilder::new() + .status_code(StatusCode::NO_CONTENT) + .build(); - Ok(res) + Ok(res) } async fn update_cat_by_id( - user: TokenUser, - Path(id): Path, - Json(payload): Json, + user: TokenUser, + Path(id): Path, + Json(payload): Json, ) -> Result, Error> { - let cat_id = to_object_id(id)?; - let update = bson::to_document(&payload).unwrap(); - - let cat = Cat::find_one_and_update( - doc! { "_id": &cat_id, "user": &user.id }, - doc! { "$set": update }, - ) - .await? - .map(PublicCat::from); - - let cat = match cat { - Some(cat) => cat, - None => { - debug!("Cat not found, returning 404 status code"); - return Err(Error::not_found()); - } - }; + let cat_id = to_object_id(id)?; + let update = bson::to_document(&payload).unwrap(); + + let cat = Cat::find_one_and_update( + doc! { "_id": &cat_id, "user": &user.id }, + doc! { "$set": update }, + ) + .await? + .map(PublicCat::from); + + let cat = match cat { + Some(cat) => cat, + None => { + debug!("Cat not found, returning 404 status code"); + return Err(Error::not_found()); + } + }; - debug!("Returning cat"); - Ok(Json(cat)) + debug!("Returning cat"); + Ok(Json(cat)) } #[derive(Deserialize)] struct CreateCat { - name: String, + name: String, } #[derive(Serialize, Deserialize)] struct UpdateCat { - name: String, + name: String, } diff --git a/src/routes/status.rs b/src/routes/status.rs index 84ff62d..46b1617 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs @@ -6,17 +6,17 @@ use tracing::debug; use crate::errors::Error; pub fn create_route() -> Router { - Router::new().route("/status", get(get_status)) + Router::new().route("/status", get(get_status)) } async fn get_status() -> Result, Error> { - debug!("Returning status"); - Ok(Json(Status { - status: "ok".to_owned(), - })) + debug!("Returning status"); + Ok(Json(Status { + status: "ok".to_owned(), + })) } #[derive(Serialize, Deserialize, Debug)] struct Status { - status: String, + status: String, } diff --git a/src/routes/user.rs b/src/routes/user.rs index dce999f..4b74bb3 100644 --- a/src/routes/user.rs +++ b/src/routes/user.rs @@ -13,89 +13,89 @@ use crate::utils::models::ModelExt; use crate::utils::token; pub fn create_route() -> Router { - Router::new() - .route("/users", post(create_user)) - .route("/users/authenticate", post(authenticate_user)) + Router::new() + .route("/users", post(create_user)) + .route("/users/authenticate", post(authenticate_user)) } async fn create_user(Json(body): Json) -> Result, Error> { - let password_hash = user::hash_password(body.password).await?; - let user = User::new(body.name, body.email, password_hash); - let user = User::create(user).await?; - let res = PublicUser::from(user); + let password_hash = user::hash_password(body.password).await?; + let user = User::new(body.name, body.email, password_hash); + let user = User::create(user).await?; + let res = PublicUser::from(user); - let res = CustomResponseBuilder::new() - .body(res) - .status_code(StatusCode::CREATED) - .build(); + let res = CustomResponseBuilder::new() + .body(res) + .status_code(StatusCode::CREATED) + .build(); - Ok(res) + Ok(res) } async fn authenticate_user( - Json(body): Json, + Json(body): Json, ) -> Result, Error> { - let email = &body.email; - let password = &body.password; - - if email.is_empty() { - debug!("Missing email, returning 400 status code"); - return Err(Error::bad_request()); - } - - if password.is_empty() { - debug!("Missing password, returning 400 status code"); - return Err(Error::bad_request()); - } - - let user = User::find_one(doc! { "email": email }, None).await?; - - let user = match user { - Some(user) => user, - None => { - debug!("User not found, returning 401"); - return Err(Error::not_found()); + let email = &body.email; + let password = &body.password; + + if email.is_empty() { + debug!("Missing email, returning 400 status code"); + return Err(Error::bad_request()); + } + + if password.is_empty() { + debug!("Missing password, returning 400 status code"); + return Err(Error::bad_request()); } - }; - if !user.is_password_match(password) { - debug!("User password is incorrect, returning 401 status code"); - return Err(Error::Authenticate(AuthenticateError::WrongCredentials)); - } + let user = User::find_one(doc! { "email": email }, None).await?; + + let user = match user { + Some(user) => user, + None => { + debug!("User not found, returning 401"); + return Err(Error::not_found()); + } + }; - if user.locked_at.is_some() { - debug!("User is locked, returning 401"); - return Err(Error::Authenticate(AuthenticateError::Locked)); - } + if !user.is_password_match(password) { + debug!("User password is incorrect, returning 401 status code"); + return Err(Error::Authenticate(AuthenticateError::WrongCredentials)); + } + + if user.locked_at.is_some() { + debug!("User is locked, returning 401"); + return Err(Error::Authenticate(AuthenticateError::Locked)); + } - let secret = SETTINGS.auth.secret.as_str(); - let token = token::create(user.clone(), secret) - .map_err(|_| Error::Authenticate(AuthenticateError::TokenCreation))?; + let secret = SETTINGS.auth.secret.as_str(); + let token = token::create(user.clone(), secret) + .map_err(|_| Error::Authenticate(AuthenticateError::TokenCreation))?; - let res = AuthenticateResponse { - access_token: token, - user: PublicUser::from(user), - }; + let res = AuthenticateResponse { + access_token: token, + user: PublicUser::from(user), + }; - Ok(Json(res)) + Ok(Json(res)) } // TODO: Validate password length #[derive(Debug, Deserialize)] struct CreateBody { - name: String, - email: String, - password: String, + name: String, + email: String, + password: String, } #[derive(Debug, Deserialize)] struct AuthorizeBody { - email: String, - password: String, + email: String, + password: String, } #[derive(Debug, Serialize, Deserialize)] pub struct AuthenticateResponse { - pub access_token: String, - pub user: PublicUser, + pub access_token: String, + pub user: PublicUser, } diff --git a/src/settings.rs b/src/settings.rs index 88d3e1a..fe6f44d 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -4,63 +4,66 @@ use serde::Deserialize; use std::{env, fmt}; pub static SETTINGS: Lazy = - Lazy::new(|| Settings::new().expect("Failed to setup settings")); + Lazy::new(|| Settings::new().expect("Failed to setup settings")); #[derive(Debug, Clone, Deserialize)] pub struct Server { - pub port: u16, + pub port: u16, } #[derive(Debug, Clone, Deserialize)] pub struct Logger { - pub level: String, + pub level: String, } #[derive(Debug, Clone, Deserialize)] pub struct Database { - pub uri: String, - pub name: String, + pub uri: String, + pub name: String, } #[derive(Debug, Clone, Deserialize)] pub struct Auth { - pub secret: String, + pub secret: String, } +// Remove the #[allow(dead_code)] attribute from the Settings struct when all the fields are being +// used. +#[allow(dead_code)] #[derive(Debug, Clone, Deserialize)] pub struct Settings { - pub environment: String, - pub server: Server, - pub logger: Logger, - pub database: Database, - pub auth: Auth, + pub environment: String, + pub server: Server, + pub logger: Logger, + pub database: Database, + pub auth: Auth, } impl Settings { - pub fn new() -> Result { - let run_mode = env::var("RUN_MODE").unwrap_or_else(|_| "development".into()); + pub fn new() -> Result { + let run_mode = env::var("RUN_MODE").unwrap_or_else(|_| "development".into()); - let mut builder = Config::builder() - .add_source(File::with_name("config/default")) - .add_source(File::with_name(&format!("config/{run_mode}")).required(false)) - .add_source(File::with_name("config/local").required(false)) - .add_source(Environment::default().separator("__")); + let mut builder = Config::builder() + .add_source(File::with_name("config/default")) + .add_source(File::with_name(&format!("config/{run_mode}")).required(false)) + .add_source(File::with_name("config/local").required(false)) + .add_source(Environment::default().separator("__")); - // Some cloud services like Heroku exposes a randomly assigned port in - // the PORT env var and there is no way to change the env var name. - if let Ok(port) = env::var("PORT") { - builder = builder.set_override("server.port", port)?; - } + // Some cloud services like Heroku exposes a randomly assigned port in + // the PORT env var and there is no way to change the env var name. + if let Ok(port) = env::var("PORT") { + builder = builder.set_override("server.port", port)?; + } - builder - .build()? - // Deserialize (and thus freeze) the entire configuration. - .try_deserialize() - } + builder + .build()? + // Deserialize (and thus freeze) the entire configuration. + .try_deserialize() + } } impl fmt::Display for Server { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "http://localhost:{}", &self.port) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "http://localhost:{}", &self.port) + } } diff --git a/src/tests/routes/cat.rs b/src/tests/routes/cat.rs index a62a590..027b6bd 100644 --- a/src/tests/routes/cat.rs +++ b/src/tests/routes/cat.rs @@ -14,152 +14,152 @@ use pretty_assertions::assert_eq; #[test] fn post_cat_route() { - #[derive(Debug, Serialize, Deserialize)] - struct Body { - name: String, - } - - let body = Body { - name: "Tigrin".to_owned(), - }; - - use_app(async move { - let user = create_user("nico@test.com").await.unwrap(); - let token = create_user_token(user.clone()).await.unwrap(); - - let client = reqwest::Client::new(); - let res = client - .post("http://localhost:8088/v1/cats") - .header("Authorization", format!("Bearer {}", token)) - .json(&body) - .send() - .await - .unwrap(); - - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::CREATED; - assert_eq!(actual, expected); - - // Body: - let body = res.json::().await.unwrap(); - assert_eq!(body.name, "Tigrin"); - assert_eq!(body.user, user.id.unwrap(), "Cat should belong to user"); - }); + #[derive(Debug, Serialize, Deserialize)] + struct Body { + name: String, + } + + let body = Body { + name: "Tigrin".to_owned(), + }; + + use_app(async move { + let user = create_user("nico@test.com").await.unwrap(); + let token = create_user_token(user.clone()).await.unwrap(); + + let client = reqwest::Client::new(); + let res = client + .post("http://localhost:8088/v1/cats") + .header("Authorization", format!("Bearer {}", token)) + .json(&body) + .send() + .await + .unwrap(); + + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::CREATED; + assert_eq!(actual, expected); + + // Body: + let body = res.json::().await.unwrap(); + assert_eq!(body.name, "Tigrin"); + assert_eq!(body.user, user.id.unwrap(), "Cat should belong to user"); + }); } #[test] fn get_cats_route() { - use_app(async move { - let user = create_user("nico@test.com").await.unwrap(); - let token = create_user_token(user.clone()).await.unwrap(); - - let tigrin = Cat::new(user.id.unwrap(), "Tigrin".to_owned()); - Cat::create(tigrin).await.unwrap(); - - let cielito = Cat::new(user.id.unwrap(), "Cielito".to_owned()); - Cat::create(cielito).await.unwrap(); - - let client = reqwest::Client::new(); - let res = client - .get("http://localhost:8088/v1/cats") - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::OK; - assert_eq!(actual, expected); - - // Response headers: - let headers = res.headers(); - assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); - // Response pagination headers: - assert_eq!(headers.get("X-Pagination-Count").unwrap(), "2"); - assert_eq!(headers.get("X-Pagination-Offset").unwrap(), "0"); - assert_eq!(headers.get("X-Pagination-Limit").unwrap(), "100"); - - // Body: - let body = res.json::>().await.unwrap(); - assert_eq!(body.len(), 2, "Should return two cats"); - - // First cat (Cielito): - let cat = body.get(0).unwrap(); - assert_eq!(cat.name, "Cielito"); - assert_eq!(cat.user, user.id.unwrap()); - - // Second cat (Tigrin): - let cat = body.get(1).unwrap(); - assert_eq!(cat.name, "Tigrin"); - assert_eq!(cat.user, user.id.unwrap()); - }); + use_app(async move { + let user = create_user("nico@test.com").await.unwrap(); + let token = create_user_token(user.clone()).await.unwrap(); + + let tigrin = Cat::new(user.id.unwrap(), "Tigrin".to_owned()); + Cat::create(tigrin).await.unwrap(); + + let cielito = Cat::new(user.id.unwrap(), "Cielito".to_owned()); + Cat::create(cielito).await.unwrap(); + + let client = reqwest::Client::new(); + let res = client + .get("http://localhost:8088/v1/cats") + .header("Authorization", format!("Bearer {}", token)) + .send() + .await + .unwrap(); + + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::OK; + assert_eq!(actual, expected); + + // Response headers: + let headers = res.headers(); + assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); + // Response pagination headers: + assert_eq!(headers.get("X-Pagination-Count").unwrap(), "2"); + assert_eq!(headers.get("X-Pagination-Offset").unwrap(), "0"); + assert_eq!(headers.get("X-Pagination-Limit").unwrap(), "100"); + + // Body: + let body = res.json::>().await.unwrap(); + assert_eq!(body.len(), 2, "Should return two cats"); + + // First cat (Cielito): + let cat = body.get(0).unwrap(); + assert_eq!(cat.name, "Cielito"); + assert_eq!(cat.user, user.id.unwrap()); + + // Second cat (Tigrin): + let cat = body.get(1).unwrap(); + assert_eq!(cat.name, "Tigrin"); + assert_eq!(cat.user, user.id.unwrap()); + }); } #[test] fn get_cat_by_id_route() { - use_app(async move { - let user = create_user("nico@test.com").await.unwrap(); - let token = create_user_token(user.clone()).await.unwrap(); - - let cholin = Cat::new(user.id.unwrap(), "Cholin".to_owned()); - let cholin = Cat::create(cholin).await.unwrap(); - - let client = reqwest::Client::new(); - let res = client - .get(format!( - "http://localhost:8088/v1/cats/{}", - cholin.id.unwrap() - )) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::OK; - assert_eq!(actual, expected); - - // Body: - let body = res.json::().await.unwrap(); - assert_eq!(body.name, "Cholin"); - assert_eq!(body.user, user.id.unwrap()); - }); + use_app(async move { + let user = create_user("nico@test.com").await.unwrap(); + let token = create_user_token(user.clone()).await.unwrap(); + + let cholin = Cat::new(user.id.unwrap(), "Cholin".to_owned()); + let cholin = Cat::create(cholin).await.unwrap(); + + let client = reqwest::Client::new(); + let res = client + .get(format!( + "http://localhost:8088/v1/cats/{}", + cholin.id.unwrap() + )) + .header("Authorization", format!("Bearer {}", token)) + .send() + .await + .unwrap(); + + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::OK; + assert_eq!(actual, expected); + + // Body: + let body = res.json::().await.unwrap(); + assert_eq!(body.name, "Cholin"); + assert_eq!(body.user, user.id.unwrap()); + }); } #[test] fn remove_cat_by_id_route() { - use_app(async move { - let user = create_user("nico@test.com").await.unwrap(); - let token = create_user_token(user.clone()).await.unwrap(); - - let tigrin = Cat::new(user.id.unwrap(), "Tigrin".to_owned()); - let tigrin = Cat::create(tigrin).await.unwrap(); - - let client = reqwest::Client::new(); - let res = client - .delete(format!( - "http://localhost:8088/v1/cats/{}", - tigrin.id.unwrap() - )) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .unwrap(); - - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::NO_CONTENT; - assert_eq!(actual, expected); - - // Cat from the database - let cat = Cat::find_by_id(&tigrin.id.unwrap()).await.unwrap(); - assert!(cat.is_none(), "Cat should be removed from the database"); - }); + use_app(async move { + let user = create_user("nico@test.com").await.unwrap(); + let token = create_user_token(user.clone()).await.unwrap(); + + let tigrin = Cat::new(user.id.unwrap(), "Tigrin".to_owned()); + let tigrin = Cat::create(tigrin).await.unwrap(); + + let client = reqwest::Client::new(); + let res = client + .delete(format!( + "http://localhost:8088/v1/cats/{}", + tigrin.id.unwrap() + )) + .header("Authorization", format!("Bearer {}", token)) + .send() + .await + .unwrap(); + + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::NO_CONTENT; + assert_eq!(actual, expected); + + // Cat from the database + let cat = Cat::find_by_id(&tigrin.id.unwrap()).await.unwrap(); + assert!(cat.is_none(), "Cat should be removed from the database"); + }); } diff --git a/src/tests/routes/status.rs b/src/tests/routes/status.rs index 4b14678..345b84d 100644 --- a/src/tests/routes/status.rs +++ b/src/tests/routes/status.rs @@ -11,19 +11,19 @@ use pretty_assertions::assert_eq; #[test] fn get_status_route() { - use_app(async { - let res = reqwest::get("http://localhost:8088/status").await.unwrap(); - let status_code = res.status(); - let body = res.json::().await.unwrap(); + use_app(async { + let res = reqwest::get("http://localhost:8088/status").await.unwrap(); + let status_code = res.status(); + let body = res.json::().await.unwrap(); - // Status code: - let actual = status_code; - let expected = StatusCode::OK; - assert_eq!(actual, expected); + // Status code: + let actual = status_code; + let expected = StatusCode::OK; + assert_eq!(actual, expected); - // Body: - let actual = body; - let expected = json!({ "status": "ok" }); - assert_json_eq!(actual, expected); - }); + // Body: + let actual = body; + let expected = json!({ "status": "ok" }); + assert_json_eq!(actual, expected); + }); } diff --git a/src/tests/routes/user.rs b/src/tests/routes/user.rs index 4b9c93d..2e5a3b7 100644 --- a/src/tests/routes/user.rs +++ b/src/tests/routes/user.rs @@ -12,73 +12,73 @@ use pretty_assertions::assert_eq; #[test] fn post_user_route() { - #[derive(Debug, Serialize, Deserialize)] - struct Body { - name: String, - email: String, - password: String, - } + #[derive(Debug, Serialize, Deserialize)] + struct Body { + name: String, + email: String, + password: String, + } - let body = Body { - name: "Nahuel".to_owned(), - email: "nahuel@gmail.com".to_owned(), - password: "Password1".to_owned(), - }; + let body = Body { + name: "Nahuel".to_owned(), + email: "nahuel@gmail.com".to_owned(), + password: "Password1".to_owned(), + }; - use_app(async move { - let client = reqwest::Client::new(); - let res = client - .post("http://localhost:8088/users") - .json(&body) - .send() - .await - .unwrap(); + use_app(async move { + let client = reqwest::Client::new(); + let res = client + .post("http://localhost:8088/users") + .json(&body) + .send() + .await + .unwrap(); - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::CREATED; - assert_eq!(actual, expected); + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::CREATED; + assert_eq!(actual, expected); - // Body: - let body = res.json::().await.unwrap(); - assert_eq!(body.name, "Nahuel"); - assert_eq!(body.email, "nahuel@gmail.com"); - }); + // Body: + let body = res.json::().await.unwrap(); + assert_eq!(body.name, "Nahuel"); + assert_eq!(body.email, "nahuel@gmail.com"); + }); } #[test] fn authenticate_user_route() { - #[derive(Debug, Serialize, Deserialize)] - struct RequestBody { - email: String, - password: String, - } + #[derive(Debug, Serialize, Deserialize)] + struct RequestBody { + email: String, + password: String, + } - let request_body = RequestBody { - email: "nahuel@gmail.com".to_owned(), - password: "Password1".to_owned(), - }; + let request_body = RequestBody { + email: "nahuel@gmail.com".to_owned(), + password: "Password1".to_owned(), + }; - use_app(async move { - create_user("nahuel@gmail.com").await.unwrap(); + use_app(async move { + create_user("nahuel@gmail.com").await.unwrap(); - let client = reqwest::Client::new(); - let res = client - .post("http://localhost:8088/users/authenticate") - .json(&request_body) - .send() - .await - .unwrap(); + let client = reqwest::Client::new(); + let res = client + .post("http://localhost:8088/users/authenticate") + .json(&request_body) + .send() + .await + .unwrap(); - // Status code: - let status_code = res.status(); - let actual = status_code; - let expected = StatusCode::OK; - assert_eq!(actual, expected); + // Status code: + let status_code = res.status(); + let actual = status_code; + let expected = StatusCode::OK; + assert_eq!(actual, expected); - // Body: - let body = res.json::().await.unwrap(); - assert_eq!(body.user.email, "nahuel@gmail.com"); - }); + // Body: + let body = res.json::().await.unwrap(); + assert_eq!(body.user.email, "nahuel@gmail.com"); + }); } diff --git a/src/tests/setup.rs b/src/tests/setup.rs index 6a4b930..ef7592e 100644 --- a/src/tests/setup.rs +++ b/src/tests/setup.rs @@ -14,34 +14,33 @@ static API: OnceCell<()> = OnceCell::const_new(); static RUNTIME: Lazy = Lazy::new(|| Runtime::new().unwrap()); pub async fn start_api_once() { - API - .get_or_init(|| async { - std::env::set_var("RUN_MODE", "test"); - - let app = create_app().await; - let port = SETTINGS.server.port; - let address = SocketAddr::from(([127, 0, 0, 1], port)); - - tokio::spawn(async move { - axum::Server::bind(&address) - .serve(app.into_make_service()) - .await - .expect("Failed to start server"); - }); + API.get_or_init(|| async { + std::env::set_var("RUN_MODE", "test"); + + let app = create_app().await; + let port = SETTINGS.server.port; + let address = SocketAddr::from(([127, 0, 0, 1], port)); + + tokio::spawn(async move { + axum::Server::bind(&address) + .serve(app.into_make_service()) + .await + .expect("Failed to start server"); + }); }) .await; } pub fn use_app(test: F) where - F: std::future::Future, + F: std::future::Future, { - RUNTIME.block_on(async move { - start_api_once().await; + RUNTIME.block_on(async move { + start_api_once().await; - Cat::delete_many(doc! {}).await.unwrap(); - User::delete_many(doc! {}).await.unwrap(); + Cat::delete_many(doc! {}).await.unwrap(); + User::delete_many(doc! {}).await.unwrap(); - test.await; - }) + test.await; + }) } diff --git a/src/tests/utils.rs b/src/tests/utils.rs index 74c4d3e..bac3837 100644 --- a/src/tests/utils.rs +++ b/src/tests/utils.rs @@ -6,19 +6,19 @@ use crate::utils::models::ModelExt; use crate::utils::token; pub async fn create_user>(email: T) -> Result { - let name = "Nahuel"; - let password = "Password1"; + let name = "Nahuel"; + let password = "Password1"; - let password_hash = hash_password(password).await?; - let user = User::new(name, email.as_ref(), password_hash); - let user = User::create(user).await?; + let password_hash = hash_password(password).await?; + let user = User::new(name, email.as_ref(), password_hash); + let user = User::create(user).await?; - Ok(user) + Ok(user) } pub async fn create_user_token(user: User) -> Result { - let secret = SETTINGS.auth.secret.as_str(); - let token = token::create(user, secret).unwrap(); + let secret = SETTINGS.auth.secret.as_str(); + let token = token::create(user, secret).unwrap(); - Ok(token) + Ok(token) } diff --git a/src/utils/authenticate_request.rs b/src/utils/authenticate_request.rs index 2893a1a..bc87bc6 100644 --- a/src/utils/authenticate_request.rs +++ b/src/utils/authenticate_request.rs @@ -1,9 +1,9 @@ use axum::{ - async_trait, - extract::{FromRequestParts, TypedHeader}, - headers::{authorization::Bearer, Authorization}, - http::request::Parts, - RequestPartsExt, + async_trait, + extract::{FromRequestParts, TypedHeader}, + headers::{authorization::Bearer, Authorization}, + http::request::Parts, + RequestPartsExt, }; use crate::errors::AuthenticateError; @@ -15,20 +15,20 @@ use crate::utils::token::TokenUser; #[async_trait] impl FromRequestParts for TokenUser where - S: Send + Sync, + S: Send + Sync, { - type Rejection = Error; + type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let TypedHeader(Authorization(bearer)) = parts - .extract::>>() - .await - .map_err(|_| AuthenticateError::InvalidToken)?; + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let TypedHeader(Authorization(bearer)) = parts + .extract::>>() + .await + .map_err(|_| AuthenticateError::InvalidToken)?; - let secret = SETTINGS.auth.secret.as_str(); - let token_data = - token::decode(bearer.token(), secret).map_err(|_| AuthenticateError::InvalidToken)?; + let secret = SETTINGS.auth.secret.as_str(); + let token_data = + token::decode(bearer.token(), secret).map_err(|_| AuthenticateError::InvalidToken)?; - Ok(token_data.claims.user) - } + Ok(token_data.claims.user) + } } diff --git a/src/utils/custom_response.rs b/src/utils/custom_response.rs index b9cfc3c..09c56b9 100644 --- a/src/utils/custom_response.rs +++ b/src/utils/custom_response.rs @@ -1,7 +1,7 @@ use axum::{ - http::header::{self, HeaderValue}, - http::StatusCode, - response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, + http::header::{self, HeaderValue}, + http::StatusCode, + response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use bytes::{BufMut, BytesMut}; use serde::Serialize; @@ -13,114 +13,111 @@ pub type CustomResponseResult = Result, Error>; #[derive(Debug)] pub struct CustomResponse { - pub body: Option, - pub status_code: StatusCode, - pub pagination: Option, + pub body: Option, + pub status_code: StatusCode, + pub pagination: Option, } pub struct CustomResponseBuilder { - pub body: Option, - pub status_code: StatusCode, - pub pagination: Option, + pub body: Option, + pub status_code: StatusCode, + pub pagination: Option, } #[derive(Debug)] pub struct ResponsePagination { - pub count: u64, - pub offset: u64, - pub limit: u32, + pub count: u64, + pub offset: u64, + pub limit: u32, } impl Default for CustomResponseBuilder where - T: Serialize, + T: Serialize, { - fn default() -> Self { - Self { - body: None, - status_code: StatusCode::OK, - pagination: None, + fn default() -> Self { + Self { + body: None, + status_code: StatusCode::OK, + pagination: None, + } } - } } impl CustomResponseBuilder where - T: Serialize, + T: Serialize, { - pub fn new() -> Self { - Self::default() - } - - pub fn body(mut self, body: T) -> Self { - self.body = Some(body); - self - } - - pub fn status_code(mut self, status_code: StatusCode) -> Self { - self.status_code = status_code; - self - } - - pub fn pagination(mut self, pagination: ResponsePagination) -> Self { - self.pagination = Some(pagination); - self - } - - pub fn build(self) -> CustomResponse { - CustomResponse { - body: self.body, - status_code: self.status_code, - pagination: self.pagination, + pub fn new() -> Self { + Self::default() + } + + pub fn body(mut self, body: T) -> Self { + self.body = Some(body); + self + } + + pub fn status_code(mut self, status_code: StatusCode) -> Self { + self.status_code = status_code; + self + } + + pub fn pagination(mut self, pagination: ResponsePagination) -> Self { + self.pagination = Some(pagination); + self + } + + pub fn build(self) -> CustomResponse { + CustomResponse { + body: self.body, + status_code: self.status_code, + pagination: self.pagination, + } } - } } impl IntoResponse for CustomResponse where - T: Serialize, + T: Serialize, { - fn into_response(self) -> Response { - let body = match self.body { - Some(body) => body, - None => return (self.status_code).into_response(), - }; - - let mut bytes = BytesMut::new().writer(); - if let Err(err) = serde_json::to_writer(&mut bytes, &body) { - error!("Error serializing response body as JSON: {:?}", err); - return (StatusCode::INTERNAL_SERVER_ERROR).into_response(); - } - - let bytes = bytes.into_inner().freeze(); - let headers = [( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), - )]; - - match self.pagination { - Some(pagination) => (self.status_code, pagination, headers, bytes).into_response(), - None => (self.status_code, headers, bytes).into_response(), + fn into_response(self) -> Response { + let body = match self.body { + Some(body) => body, + None => return (self.status_code).into_response(), + }; + + let mut bytes = BytesMut::new().writer(); + if let Err(err) = serde_json::to_writer(&mut bytes, &body) { + error!("Error serializing response body as JSON: {:?}", err); + return (StatusCode::INTERNAL_SERVER_ERROR).into_response(); + } + + let bytes = bytes.into_inner().freeze(); + let headers = [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), + )]; + + match self.pagination { + Some(pagination) => (self.status_code, pagination, headers, bytes).into_response(), + None => (self.status_code, headers, bytes).into_response(), + } } - } } impl IntoResponseParts for ResponsePagination { - type Error = (StatusCode, String); + type Error = (StatusCode, String); - fn into_response_parts(self, mut res: ResponseParts) -> Result { - res - .headers_mut() - .insert("x-pagination-count", self.count.into()); + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.headers_mut() + .insert("x-pagination-count", self.count.into()); - res - .headers_mut() - .insert("x-pagination-offset", self.offset.into()); + res.headers_mut() + .insert("x-pagination-offset", self.offset.into()); - res - .headers_mut() - .insert("x-pagination-limit", self.limit.into()); + res.headers_mut() + .insert("x-pagination-limit", self.limit.into()); - Ok(res) - } + Ok(res) + } } diff --git a/src/utils/date.rs b/src/utils/date.rs index 7641cd3..52c9484 100644 --- a/src/utils/date.rs +++ b/src/utils/date.rs @@ -3,5 +3,5 @@ use chrono::Utc; pub type Date = bson::DateTime; pub fn now() -> Date { - Utc::now().into() + Utc::now().into() } diff --git a/src/utils/models.rs b/src/utils/models.rs index 2ae2fad..557d45d 100644 --- a/src/utils/models.rs +++ b/src/utils/models.rs @@ -27,176 +27,176 @@ use crate::errors::Error; #[async_trait] pub trait ModelExt where - Self: WitherModel + Validate, + Self: WitherModel + Validate, { - async fn create(mut model: Self) -> Result { - let connection = database::connection().await; - model.validate().map_err(|_error| Error::bad_request())?; - model.save(connection, None).await.map_err(Error::Wither)?; - - Ok(model) - } - - async fn find_by_id(id: &ObjectId) -> Result, Error> { - let connection = database::connection().await; - ::find_one(connection, doc! { "_id": id }, None) - .await - .map_err(Error::Wither) - } - - async fn find_one(query: Document, options: O) -> Result, Error> - where - O: Into> + Send, - { - let connection = database::connection().await; - ::find_one(connection, query, options) - .await - .map_err(Error::Wither) - } - - async fn find(query: Document, options: O) -> Result, Error> - where - O: Into> + Send, - { - let connection = database::connection().await; - ::find(connection, query, options) - .await - .map_err(Error::Wither)? - .try_collect::>() - .await - .map_err(Error::Wither) - } - - async fn find_and_count(query: Document, options: O) -> Result<(Vec, u64), Error> - where - O: Into> + Send, - { - let connection = database::connection().await; - - let count = Self::collection(connection) - .count_documents(query.clone(), None) - .await - .map_err(Error::Mongo)?; - - let items = ::find(connection, query, options.into()) - .await - .map_err(Error::Wither)? - .try_collect::>() - .await - .map_err(Error::Wither)?; - - Ok((items, count)) - } - - async fn cursor(query: Document, options: O) -> Result, Error> - where - O: Into> + Send, - { - let connection = database::connection().await; - ::find(connection, query, options) - .await - .map_err(Error::Wither) - } - - async fn find_one_and_update(query: Document, update: Document) -> Result, Error> { - let connection = database::connection().await; - let options = FindOneAndUpdateOptions::builder() - .return_document(ReturnDocument::After) - .build(); - - ::find_one_and_update(connection, query, update, options) - .await - .map_err(Error::Wither) - } - - async fn update_one( - query: Document, - update: Document, - options: O, - ) -> Result - where - O: Into> + Send, - { - let connection = database::connection().await; - Self::collection(connection) - .update_one(query, update, options) - .await - .map_err(Error::Mongo) - } - - async fn update_many( - query: Document, - update: Document, - options: O, - ) -> Result - where - O: Into> + Send, - { - let connection = database::connection().await; - Self::collection(connection) - .update_many(query, update, options) - .await - .map_err(Error::Mongo) - } - - async fn delete_many(query: Document) -> Result { - let connection = database::connection().await; - ::delete_many(connection, query, None) - .await - .map_err(Error::Wither) - } - - async fn delete_one(query: Document) -> Result { - let connection = database::connection().await; - Self::collection(connection) - .delete_one(query, None) - .await - .map_err(Error::Mongo) - } - - async fn count(query: Document) -> Result { - let connection = database::connection().await; - Self::collection(connection) - .count_documents(query, None) - .await - .map_err(Error::Mongo) - } - - async fn exists(query: Document) -> Result { - let connection = database::connection().await; - let count = Self::collection(connection) - .count_documents(query, None) - .await - .map_err(Error::Mongo)?; - - Ok(count > 0) - } - - async fn aggregate(pipeline: Vec) -> Result, Error> - where - A: Serialize + DeserializeOwned, - { - let connection = database::connection().await; - - let documents = Self::collection(connection) - .aggregate(pipeline, None) - .await - .map_err(Error::Mongo)? - .try_collect::>() - .await - .map_err(Error::Mongo)?; - - let documents = documents - .into_iter() - .map(|document| from_bson::(Bson::Document(document))) - .collect::, bson::de::Error>>() - .map_err(Error::SerializeMongoResponse)?; - - Ok(documents) - } - - async fn sync_indexes() -> Result<(), Error> { - let connection = database::connection().await; - Self::sync(connection).await.map_err(Error::Wither) - } + async fn create(mut model: Self) -> Result { + let connection = database::connection().await; + model.validate().map_err(|_error| Error::bad_request())?; + model.save(connection, None).await.map_err(Error::Wither)?; + + Ok(model) + } + + async fn find_by_id(id: &ObjectId) -> Result, Error> { + let connection = database::connection().await; + ::find_one(connection, doc! { "_id": id }, None) + .await + .map_err(Error::Wither) + } + + async fn find_one(query: Document, options: O) -> Result, Error> + where + O: Into> + Send, + { + let connection = database::connection().await; + ::find_one(connection, query, options) + .await + .map_err(Error::Wither) + } + + async fn find(query: Document, options: O) -> Result, Error> + where + O: Into> + Send, + { + let connection = database::connection().await; + ::find(connection, query, options) + .await + .map_err(Error::Wither)? + .try_collect::>() + .await + .map_err(Error::Wither) + } + + async fn find_and_count(query: Document, options: O) -> Result<(Vec, u64), Error> + where + O: Into> + Send, + { + let connection = database::connection().await; + + let count = Self::collection(connection) + .count_documents(query.clone(), None) + .await + .map_err(Error::Mongo)?; + + let items = ::find(connection, query, options.into()) + .await + .map_err(Error::Wither)? + .try_collect::>() + .await + .map_err(Error::Wither)?; + + Ok((items, count)) + } + + async fn cursor(query: Document, options: O) -> Result, Error> + where + O: Into> + Send, + { + let connection = database::connection().await; + ::find(connection, query, options) + .await + .map_err(Error::Wither) + } + + async fn find_one_and_update(query: Document, update: Document) -> Result, Error> { + let connection = database::connection().await; + let options = FindOneAndUpdateOptions::builder() + .return_document(ReturnDocument::After) + .build(); + + ::find_one_and_update(connection, query, update, options) + .await + .map_err(Error::Wither) + } + + async fn update_one( + query: Document, + update: Document, + options: O, + ) -> Result + where + O: Into> + Send, + { + let connection = database::connection().await; + Self::collection(connection) + .update_one(query, update, options) + .await + .map_err(Error::Mongo) + } + + async fn update_many( + query: Document, + update: Document, + options: O, + ) -> Result + where + O: Into> + Send, + { + let connection = database::connection().await; + Self::collection(connection) + .update_many(query, update, options) + .await + .map_err(Error::Mongo) + } + + async fn delete_many(query: Document) -> Result { + let connection = database::connection().await; + ::delete_many(connection, query, None) + .await + .map_err(Error::Wither) + } + + async fn delete_one(query: Document) -> Result { + let connection = database::connection().await; + Self::collection(connection) + .delete_one(query, None) + .await + .map_err(Error::Mongo) + } + + async fn count(query: Document) -> Result { + let connection = database::connection().await; + Self::collection(connection) + .count_documents(query, None) + .await + .map_err(Error::Mongo) + } + + async fn exists(query: Document) -> Result { + let connection = database::connection().await; + let count = Self::collection(connection) + .count_documents(query, None) + .await + .map_err(Error::Mongo)?; + + Ok(count > 0) + } + + async fn aggregate(pipeline: Vec) -> Result, Error> + where + A: Serialize + DeserializeOwned, + { + let connection = database::connection().await; + + let documents = Self::collection(connection) + .aggregate(pipeline, None) + .await + .map_err(Error::Mongo)? + .try_collect::>() + .await + .map_err(Error::Mongo)?; + + let documents = documents + .into_iter() + .map(|document| from_bson::(Bson::Document(document))) + .collect::, bson::de::Error>>() + .map_err(Error::SerializeMongoResponse)?; + + Ok(documents) + } + + async fn sync_indexes() -> Result<(), Error> { + let connection = database::connection().await; + Self::sync(connection).await.map_err(Error::Wither) + } } diff --git a/src/utils/pagination.rs b/src/utils/pagination.rs index 2bd1562..0aac508 100644 --- a/src/utils/pagination.rs +++ b/src/utils/pagination.rs @@ -8,44 +8,44 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] struct Limit { - limit: u32, + limit: u32, } impl Default for Limit { - fn default() -> Self { - Self { limit: 100 } - } + fn default() -> Self { + Self { limit: 100 } + } } #[derive(Debug, Clone, Default, Deserialize)] struct Offset { - offset: u64, + offset: u64, } #[derive(Debug, Clone)] pub struct Pagination { - /// The number of documents to skip before counting. - pub offset: u64, - /// The maximum number of documents to query. - pub limit: u32, + /// The number of documents to skip before counting. + pub offset: u64, + /// The maximum number of documents to query. + pub limit: u32, } #[async_trait] impl FromRequestParts for Pagination where - S: Send + Sync, + S: Send + Sync, { - type Rejection = (StatusCode, &'static str); + type Rejection = (StatusCode, &'static str); - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Query(Limit { limit }) = Query::::from_request_parts(parts, state) - .await - .unwrap_or_default(); + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let Query(Limit { limit }) = Query::::from_request_parts(parts, state) + .await + .unwrap_or_default(); - let Query(Offset { offset }) = Query::::from_request_parts(parts, state) - .await - .unwrap_or_default(); + let Query(Offset { offset }) = Query::::from_request_parts(parts, state) + .await + .unwrap_or_default(); - Ok(Self { limit, offset }) - } + Ok(Self { limit, offset }) + } } diff --git a/src/utils/to_object_id.rs b/src/utils/to_object_id.rs index 45590f8..979eea7 100644 --- a/src/utils/to_object_id.rs +++ b/src/utils/to_object_id.rs @@ -3,5 +3,5 @@ use bson::oid::ObjectId; use crate::errors::Error; pub fn to_object_id>(id: S) -> Result { - ObjectId::parse_str(id.as_ref()).map_err(|_| Error::ParseObjectID(id.as_ref().to_string())) + ObjectId::parse_str(id.as_ref()).map_err(|_| Error::ParseObjectID(id.as_ref().to_string())) } diff --git a/src/utils/token.rs b/src/utils/token.rs index 270b8b4..25bc237 100644 --- a/src/utils/token.rs +++ b/src/utils/token.rs @@ -12,47 +12,47 @@ static HEADER: Lazy

= Lazy::new(Header::default); #[derive(Debug, Serialize, Deserialize)] pub struct TokenUser { - pub id: ObjectId, - pub name: String, - pub email: String, + pub id: ObjectId, + pub name: String, + pub email: String, } impl From for TokenUser { - fn from(user: User) -> Self { - Self { - id: user.id.unwrap(), - name: user.name.clone(), - email: user.email, + fn from(user: User) -> Self { + Self { + id: user.id.unwrap(), + name: user.name.clone(), + email: user.email, + } } - } } #[derive(Debug, Serialize, Deserialize)] pub struct Claims { - pub exp: usize, // Expiration time (as UTC timestamp). validate_exp defaults to true in validation - pub iat: usize, // Issued at (as UTC timestamp) - pub user: TokenUser, + pub exp: usize, // Expiration time (as UTC timestamp). validate_exp defaults to true in validation + pub iat: usize, // Issued at (as UTC timestamp) + pub user: TokenUser, } impl Claims { - pub fn new(user: User) -> Self { - Self { - exp: (chrono::Local::now() + chrono::Duration::days(30)).timestamp() as usize, - iat: chrono::Local::now().timestamp() as usize, - user: TokenUser::from(user), + pub fn new(user: User) -> Self { + Self { + exp: (chrono::Local::now() + chrono::Duration::days(30)).timestamp() as usize, + iat: chrono::Local::now().timestamp() as usize, + user: TokenUser::from(user), + } } - } } pub fn create(user: User, secret: &str) -> Result { - let encoding_key = EncodingKey::from_secret(secret.as_ref()); - let claims = Claims::new(user); + let encoding_key = EncodingKey::from_secret(secret.as_ref()); + let claims = Claims::new(user); - jsonwebtoken::encode(&HEADER, &claims, &encoding_key) + jsonwebtoken::encode(&HEADER, &claims, &encoding_key) } pub fn decode(token: &str, secret: &str) -> TokenResult { - let decoding_key = DecodingKey::from_secret(secret.as_ref()); + let decoding_key = DecodingKey::from_secret(secret.as_ref()); - jsonwebtoken::decode::(token, &decoding_key, &VALIDATION) + jsonwebtoken::decode::(token, &decoding_key, &VALIDATION) }