Skip to content

Commit

Permalink
Added ability to return an error for failed SecuritySchema checker. (
Browse files Browse the repository at this point in the history
…#625)

* feat: added ability for `securityscheme` checker to return `option` or `result`
  • Loading branch information
NexRX authored Aug 10, 2023
1 parent 10dd28f commit dccbd34
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 27 deletions.
18 changes: 11 additions & 7 deletions poem-openapi-derive/src/security_scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,16 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult<TokenStream> {
let register_security_scheme =
args.generate_register_security_scheme(&crate_name, &oai_typename)?;
let from_request = args.generate_from_request(&crate_name);
let checker = args.checker.as_ref().map(|path| {
quote! {
let output = ::std::option::Option::ok_or(#path(&req, output).await, #crate_name::error::AuthorizationError)?;
}
});
let path = args.checker.as_ref();

let output = match path {
Some(_) => quote! {
let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?;
},
None => quote! {
let output = #from_request?;
},
};

let expanded = quote! {
#[#crate_name::__private::poem::async_trait]
Expand All @@ -468,8 +473,7 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult<TokenStream> {
_param_opts: #crate_name::ExtractParamOptions<Self::ParamType>,
) -> #crate_name::__private::poem::Result<Self> {
let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap();
let output = #from_request?;
#checker
#output
::std::result::Result::Ok(Self(output))
}
}
Expand Down
32 changes: 31 additions & 1 deletion poem-openapi/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod bearer;
use poem::{Request, Result};

pub use self::{api_key::ApiKey, basic::Basic, bearer::Bearer};
use crate::{base::UrlQuery, registry::MetaParamIn};
use crate::{base::UrlQuery, error::AuthorizationError, registry::MetaParamIn};

/// Represents a basic authorization extractor.
pub trait BasicAuthorization: Sized {
Expand All @@ -31,3 +31,33 @@ pub trait ApiKeyAuthorization: Sized {
in_type: MetaParamIn,
) -> Result<Self>;
}

/// Facilitates the conversion of `Option` into `Results`, for `SecuritySchema` checker.
#[doc(hidden)]
pub enum CheckerReturn<T> {
Result(Result<T>),
Option(Option<T>),
}

impl<T> CheckerReturn<T> {
pub fn into_result(self) -> Result<T> {
match self {
Self::Result(result) => result,
Self::Option(option) => option.ok_or(AuthorizationError.into()),
}
}
}

impl<T> From<poem::Result<T>> for CheckerReturn<T> {
#[inline]
fn from(result: Result<T>) -> Self {
Self::Result(result)
}
}

impl<T> From<Option<T>> for CheckerReturn<T> {
#[inline]
fn from(option: Option<T>) -> Self {
Self::Option(option)
}
}
33 changes: 16 additions & 17 deletions poem-openapi/src/docs/security_scheme.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,31 @@ Define a OpenAPI Security Scheme.

# Macro parameters

| Attribute | Description | Type | Optional |
|--------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|----------|
| rename | Rename the security scheme. | string | Y |
| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N |
| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y |
| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y |
| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y |
| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y |
| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y |
| checker | Specify a function to check the original authentication information and convert it to the return type of this function. This function must return `Option<T>`, and return `None` if check fails. | string | Y |
| Attribute | Description | Type | Optional |
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------- | -------- |
| rename | Rename the security scheme. | string | Y |
| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N |
| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y |
| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y |
| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y |
| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y |
| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y |
| checker | Specify a function to check the original authentication information and convert it to the return type of this function. This function must return `Option<T>` or `poem::Result<T>`, with `None` meaning a General Authorization error and an `Err` reflecting the error supplied. | string | Y |

# OAuthFlows

| Attribute | description | Type | Optional |
|--------------------|----------------------------------------------------------|-----------|----------|
| ------------------ | -------------------------------------------------------- | --------- | -------- |
| implicit | Configuration for the OAuth Implicit flow | OAuthFlow | Y |
| password | Configuration for the OAuth Resource Owner Password flow | OAuthFlow | Y |
| client_credentials | Configuration for the OAuth Client Credentials flow | OAuthFlow | Y |
| authorization_code | Configuration for the OAuth Authorization Code flow | OAuthFlow | Y |

# OAuthFlow

| Attribute | description | Type | Optional |
|-------------------|----------------------------------------------------------------------------------------------|-------------|----------|
| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y |
| Attribute | description | Type | Optional |
| ----------------- | -------------------------------------------------------------------------------------------------- | ----------- | -------- |
| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y |
| token_url | `password` `client_credentials` `authorization_code` The token URL to be used for this flow. | string | Y |
| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y |
| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y |

| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y |
| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y |
2 changes: 1 addition & 1 deletion poem-openapi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,5 @@ pub mod __private {
pub use serde;
pub use serde_json;

pub use crate::base::UrlQuery;
pub use crate::{auth::CheckerReturn, base::UrlQuery};
}
1 change: 1 addition & 0 deletions poem-openapi/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ mod tests {

#[test]
#[allow(clippy::assertions_on_constants)]
#[allow(unused_allocation)]
fn box_type() {
assert!(Box::<i32>::IS_REQUIRED);
assert_eq!(Box::<i32>::name(), "integer(int32)");
Expand Down
105 changes: 104 additions & 1 deletion poem-openapi/tests/security_scheme.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use poem::{
http::header,
error::ResponseError,
http::{header, StatusCode},
test::TestClient,
web::{cookie::Cookie, headers},
Request,
};
use poem_openapi::{
auth::{ApiKey, Basic, Bearer},
Expand Down Expand Up @@ -435,3 +437,104 @@ async fn oauth2_auth() {
}
);
}

#[tokio::test]
async fn checker_result() {
#[derive(SecurityScheme)]
#[oai(rename = "Checker Option", ty = "basic", checker = "extract_string")]
struct MySecurityScheme(Basic);

#[derive(Debug, thiserror::Error)]
#[error("Your account is disabled")]
struct AccountDisabledError;

impl ResponseError for AccountDisabledError {
fn status(&self) -> StatusCode {
StatusCode::FORBIDDEN
}
}

async fn extract_string(_req: &Request, basic: Basic) -> poem::Result<Basic> {
if basic.username != "Disabled" {
Ok(basic)
} else {
Err(AccountDisabledError)?
}
}

let mut registry = Registry::new();
MySecurityScheme::register(&mut registry);

struct MyApi;

#[OpenApi]
impl MyApi {
#[oai(path = "/test", method = "get")]
async fn test(&self, auth: MySecurityScheme) -> PlainText<String> {
PlainText(format!("Authed: {}", auth.0.username))
}
}

let service = OpenApiService::new(MyApi, "test", "1.0");
let client = TestClient::new(service);
let resp = client
.get("/test")
.typed_header(headers::Authorization::basic("Enabled", "password"))
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("Authed: Enabled".to_string()).await;

let resp = client
.get("/test")
.typed_header(headers::Authorization::basic("Disabled", "password"))
.send()
.await;
resp.assert_status(StatusCode::FORBIDDEN);
resp.assert_text("Your account is disabled").await;
}

#[tokio::test]
async fn checker_option() {
#[derive(SecurityScheme)]
#[oai(rename = "Checker Option", ty = "basic", checker = "extract_string")]
struct MySecurityScheme(Basic);

async fn extract_string(_req: &Request, basic: Basic) -> Option<Basic> {
if basic.username != "Disabled" {
Some(basic)
} else {
None
}
}

let mut registry = Registry::new();
MySecurityScheme::register(&mut registry);

struct MyApi;

#[OpenApi]
impl MyApi {
#[oai(path = "/test", method = "get")]
async fn test(&self, auth: MySecurityScheme) -> PlainText<String> {
PlainText(format!("Authed: {}", auth.0.username))
}
}

let service = OpenApiService::new(MyApi, "test", "1.0");
let client = TestClient::new(service);
let resp = client
.get("/test")
.typed_header(headers::Authorization::basic("Enabled", "password"))
.send()
.await;
resp.assert_status_is_ok();
resp.assert_text("Authed: Enabled".to_string()).await;

let resp = client
.get("/test")
.typed_header(headers::Authorization::basic("Disabled", "password"))
.send()
.await;
resp.assert_status(StatusCode::UNAUTHORIZED);
}

0 comments on commit dccbd34

Please sign in to comment.