diff --git a/Cargo.lock b/Cargo.lock index 03512c36a3..0514536825 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2013,6 +2013,14 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "macros" +version = "0.1.0" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -4030,6 +4038,7 @@ dependencies = [ "lettre", "libsqlite3-sys", "log", + "macros", "mimalloc", "num-derive", "num-traits", diff --git a/Cargo.toml b/Cargo.toml index 9ab1baad67..9ea3dc5c56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,5 @@ +workspace = { members = ["macros"] } + [package] name = "vaultwarden" version = "1.0.0" @@ -39,6 +41,8 @@ unstable = [] syslog = "7.0.0" [dependencies] +macros = { path = "./macros" } + # Logging log = "0.4.22" fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] } diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 0000000000..184accfc5c --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "macros" +version = "0.1.0" +edition = "2021" + +[lib] +name = "macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +quote = "1.0.38" +syn = "2.0.94" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 0000000000..2c4e297a3b --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,31 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; + +#[proc_macro_derive(IdFromParam)] +pub fn derive_from_param(input: TokenStream) -> TokenStream { + let ast = syn::parse(input).unwrap(); + + impl_derive_macro(&ast) +} + +fn impl_derive_macro(ast: &syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + let gen = quote! { + #[automatically_derived] + impl<'r> rocket::request::FromParam<'r> for #name { + type Error = (); + + #[inline(always)] + fn from_param(param: &'r str) -> Result { + if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { + Ok(Self(param.to_string())) + } else { + Err(()) + } + } + } + }; + gen.into() +} diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs index c86bd636e7..09348f78b1 100644 --- a/src/db/models/attachment.rs +++ b/src/db/models/attachment.rs @@ -2,11 +2,11 @@ use std::io::ErrorKind; use bigdecimal::{BigDecimal, ToPrimitive}; use derive_more::{AsRef, Deref, Display}; -use rocket::request::FromParam; use serde_json::Value; use super::{CipherId, OrganizationId, UserId}; use crate::CONFIG; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -230,18 +230,19 @@ impl Attachment { } } -#[derive(Clone, Debug, AsRef, Deref, DieselNewType, Display, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive( + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, +)] pub struct AttachmentId(pub String); - -impl<'r> FromParam<'r> for AttachmentId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/auth_request.rs b/src/db/models/auth_request.rs index 3417d07eca..6a04c57da6 100644 --- a/src/db/models/auth_request.rs +++ b/src/db/models/auth_request.rs @@ -2,7 +2,7 @@ use super::{DeviceId, OrganizationId, UserId}; use crate::crypto::ct_eq; use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; +use macros::IdFromParam; db_object! { #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)] @@ -162,19 +162,19 @@ impl AuthRequest { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct AuthRequestId(String); - -impl<'r> FromParam<'r> for AuthRequestId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs index 447ca5b699..af34262ca3 100644 --- a/src/db/models/cipher.rs +++ b/src/db/models/cipher.rs @@ -2,15 +2,14 @@ use crate::util::LowerCase; use crate::CONFIG; use chrono::{NaiveDateTime, TimeDelta, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; use serde_json::Value; use super::{ Attachment, CollectionCipher, CollectionId, Favorite, FolderCipher, FolderId, Group, Membership, MembershipStatus, MembershipType, OrganizationId, User, UserId, }; - use crate::api::core::{CipherData, CipherSyncData, CipherSyncType}; +use macros::IdFromParam; use std::borrow::Cow; @@ -721,7 +720,11 @@ impl Cipher { }} } - pub async fn find_by_uuid_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option { + pub async fn find_by_uuid_and_org( + cipher_uuid: &CipherId, + org_uuid: &OrganizationId, + conn: &mut DbConn, + ) -> Option { db_run! {conn: { ciphers::table .filter(ciphers::uuid.eq(cipher_uuid)) @@ -1055,19 +1058,19 @@ impl Cipher { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct CipherId(String); - -impl<'r> FromParam<'r> for CipherId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/collection.rs b/src/db/models/collection.rs index 59d29eb267..d0e8a46565 100644 --- a/src/db/models/collection.rs +++ b/src/db/models/collection.rs @@ -1,5 +1,4 @@ use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; use serde_json::Value; use super::{ @@ -7,6 +6,7 @@ use super::{ User, UserId, }; use crate::CONFIG; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -815,19 +815,19 @@ impl From for CollectionMembership { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct CollectionId(String); - -impl<'r> FromParam<'r> for CollectionId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/device.rs b/src/db/models/device.rs index 69c96bec0d..0f1afd0fc0 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -1,9 +1,9 @@ use chrono::{NaiveDateTime, Utc}; use derive_more::{Display, From}; -use rocket::request::FromParam; use super::UserId; use crate::{crypto, CONFIG}; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -335,24 +335,7 @@ impl DeviceType { } } -#[derive(Clone, Debug, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive( + Clone, Debug, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, IdFromParam, +)] pub struct DeviceId(String); - -impl DeviceId { - pub fn empty() -> Self { - Self(String::from("00000000-0000-0000-0000-000000000000")) - } -} - -impl<'r> FromParam<'r> for DeviceId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/emergency_access.rs b/src/db/models/emergency_access.rs index 29c688b990..7a62bdb4bc 100644 --- a/src/db/models/emergency_access.rs +++ b/src/db/models/emergency_access.rs @@ -1,11 +1,10 @@ use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; use serde_json::Value; -use crate::{api::EmptyResult, db::DbConn, error::MapResult}; - use super::{User, UserId}; +use crate::{api::EmptyResult, db::DbConn, error::MapResult}; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -365,19 +364,19 @@ impl EmergencyAccess { // endregion #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct EmergencyAccessId(String); - -impl<'r> FromParam<'r> for EmergencyAccessId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/folder.rs b/src/db/models/folder.rs index ea7208bd99..deb8819eb0 100644 --- a/src/db/models/folder.rs +++ b/src/db/models/folder.rs @@ -1,9 +1,9 @@ use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; use serde_json::Value; use super::{CipherId, User, UserId}; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -235,19 +235,19 @@ impl FolderCipher { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct FolderId(String); - -impl<'r> FromParam<'r> for FolderId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/group.rs b/src/db/models/group.rs index 6720e9b7ce..ca3a9acb3b 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -4,7 +4,7 @@ use crate::db::DbConn; use crate::error::MapResult; use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; +use macros::IdFromParam; use serde_json::Value; db_object! { @@ -605,19 +605,19 @@ impl GroupUser { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct GroupId(String); - -impl<'r> FromParam<'r> for GroupId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs index 7ad9b970cb..ea50e56750 100644 --- a/src/db/models/organization.rs +++ b/src/db/models/organization.rs @@ -1,7 +1,6 @@ use chrono::{NaiveDateTime, Utc}; use derive_more::{AsRef, Deref, Display, From}; use num_traits::FromPrimitive; -use rocket::request::FromParam; use serde_json::Value; use std::{ cmp::Ordering, @@ -13,6 +12,7 @@ use super::{ OrgPolicyType, TwoFactor, User, UserId, }; use crate::CONFIG; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -1121,41 +1121,42 @@ impl OrganizationApiKey { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] #[deref(forward)] #[from(forward)] pub struct OrganizationId(String); -impl<'r> FromParam<'r> for OrganizationId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} - -#[derive(Clone, Debug, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive( + Clone, + Debug, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, +)] pub struct MembershipId(String); -impl<'r> FromParam<'r> for MembershipId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -} - #[derive(Clone, Debug, DieselNewType, Display, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct OrgApiKeyId(String); diff --git a/src/db/models/send.rs b/src/db/models/send.rs index f1ba0c0cca..ee7685a340 100644 --- a/src/db/models/send.rs +++ b/src/db/models/send.rs @@ -353,12 +353,25 @@ impl Send { // separate namespace to avoid name collision with std::marker::Send pub mod id { use derive_more::{AsRef, Deref, Display, From}; - use rocket::request::FromParam; + use macros::IdFromParam; use std::marker::Send; use std::path::Path; #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + AsRef, + Deref, + DieselNewType, + Display, + From, + FromForm, + Hash, + PartialEq, + Eq, + Serialize, + Deserialize, + IdFromParam, )] pub struct SendId(String); @@ -369,20 +382,9 @@ pub mod id { } } - impl<'r> FromParam<'r> for SendId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } - } - - #[derive(Clone, Debug, AsRef, Deref, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize)] + #[derive( + Clone, Debug, AsRef, Deref, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, IdFromParam, + )] pub struct SendFileId(String); impl AsRef for SendFileId { @@ -391,17 +393,4 @@ pub mod id { Path::new(&self.0) } } - - impl<'r> FromParam<'r> for SendFileId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } - } } diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 411b9ac72a..302b2c9d86 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -1,6 +1,5 @@ use chrono::{NaiveDateTime, TimeDelta, Utc}; use derive_more::{AsRef, Deref, Display, From}; -use rocket::request::FromParam; use serde_json::Value; use super::{ @@ -14,6 +13,7 @@ use crate::{ util::{format_date, get_uuid, retry}, CONFIG, }; +use macros::IdFromParam; db_object! { #[derive(Identifiable, Queryable, Insertable, AsChangeset)] @@ -460,21 +460,21 @@ impl Invitation { } #[derive( - Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize, + Clone, + Debug, + DieselNewType, + FromForm, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + AsRef, + Deref, + Display, + From, + IdFromParam, )] #[deref(forward)] #[from(forward)] pub struct UserId(String); - -impl<'r> FromParam<'r> for UserId { - type Error = (); - - #[inline(always)] - fn from_param(param: &'r str) -> Result { - if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { - Ok(Self(param.to_string())) - } else { - Err(()) - } - } -}