diff --git a/actix-web-grants/src/guards.rs b/actix-web-grants/src/guards.rs index 33705b3..a8da008 100644 --- a/actix-web-grants/src/guards.rs +++ b/actix-web-grants/src/guards.rs @@ -17,7 +17,7 @@ use std::hash::Hash; /// .wrap(GrantsMiddleware::with_extractor(extract)) /// .service(web::resource("/admin") /// .to(|| async { HttpResponse::Ok().finish() }) -/// .guard(AuthorityGuard::new("ROLE_ADMIN".to_string()))) +/// .guard(AuthorityGuard::contains("ROLE_ADMIN".to_string()))) /// }); /// } /// @@ -29,22 +29,57 @@ use std::hash::Hash; /// Ok(HashSet::from(["ROLE_ADMIN".to_string()])) /// } /// ``` -pub struct AuthorityGuard { - allow_authority: Type, + +pub struct AuthorityGuard { + allow_authority: Type, +} + +pub enum Type { + Single(T), + Any(Vec), + All(Vec), } -impl AuthorityGuard { - pub fn new(allow_authority: Type) -> AuthorityGuard { - AuthorityGuard { allow_authority } +impl AuthorityGuard { + fn create(allow_authority: Type) -> AuthorityGuard { + AuthorityGuard { + allow_authority: allow_authority, + } + } + + #[deprecated] + pub fn new(allow_authority: T) -> AuthorityGuard { + Self::contains(allow_authority) + } + + pub fn contains(allow_authority: T) -> AuthorityGuard { + Self::create(Type::Single(allow_authority)) + } + + pub fn all(allow_authority: impl IntoIterator) -> AuthorityGuard { + Self::create(Type::All(allow_authority.into_iter().collect())) + } + + pub fn any(allow_authority: impl IntoIterator) -> AuthorityGuard { + Self::create(Type::Any(allow_authority.into_iter().collect())) } } -impl Guard for AuthorityGuard { +impl Guard for AuthorityGuard { fn check(&self, request: &GuardContext) -> bool { - request - .req_data() - .get::>() - .filter(|details| details.has_authority(&self.allow_authority)) - .is_some() + let req_data = request.req_data(); + let details = req_data.get::>(); + match &self.allow_authority { + Type::Single(s) => details + .filter(|details| details.has_authority(&s)) + .is_some(), + Type::Any(items) => details + .filter(|details| details.has_any_authority(&items.iter().collect::>())) + .is_some(), + Type::All(items) => details + .filter(|details| details.has_authorities(&items.iter().collect::>())) + .is_some(), + } } } +