Skip to content

Commit

Permalink
Add example on how to use Api Gateway authorizers in Axum.
Browse files Browse the repository at this point in the history
This also shows how to work with the RequestExt trait and the RequestContext object.

Signed-off-by: David Calavera <[email protected]>
  • Loading branch information
calavera committed Feb 21, 2024
1 parent 7a3ab97 commit 93a0c3f
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 0 deletions.
14 changes: 14 additions & 0 deletions examples/http-axum-apigw-authorizer/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "http-axum-apigw-authorizer"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
lambda_http = { path = "../../lambda-http" }
lambda_runtime = { path = "../../lambda-runtime" }
serde = "1.0.196"
serde_json = "1.0"
tokio = { version = "1", features = ["macros"] }
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] }
13 changes: 13 additions & 0 deletions examples/http-axum-apigw-authorizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Axum example that integrates with Api Gateway authorizers

This example shows how to extract information from the Api Gateway Request Authorizer in an Axum handler.

## Build & Deploy

1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation)
2. Build the function with `cargo lambda build --release`
3. Deploy the function to AWS Lambda with `cargo lambda deploy --iam-role YOUR_ROLE`

## Build for ARM 64

Build the function with `cargo lambda build --release --arm64`
80 changes: 80 additions & 0 deletions examples/http-axum-apigw-authorizer/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use axum::{
async_trait,
extract::{FromRequest, Request},
http::StatusCode,
response::Json,
routing::get,
Router,
};
use lambda_http::{run, Error, RequestExt};
use serde_json::{json, Value};
use std::{collections::HashMap, env::set_var};

struct AuthorizerField(String);
struct AuthorizerFields(HashMap<String, serde_json::Value>);

#[async_trait]
impl<S> FromRequest<S> for AuthorizerField
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
req.request_context_ref()
.and_then(|r| r.authorizer())
.and_then(|a| a.fields.get("field_name"))
.and_then(|f| f.as_str())
.map(|v| Self(v.to_string()))
.ok_or_else(|| (StatusCode::BAD_REQUEST, "`field_name` authorizer field is missing"))
}
}

#[async_trait]
impl<S> FromRequest<S> for AuthorizerFields
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
req.request_context_ref()
.and_then(|r| r.authorizer())
.map(|a| Self(a.fields.clone()))
.ok_or_else(|| (StatusCode::BAD_REQUEST, "authorizer is missing"))
}
}

async fn extract_field(AuthorizerField(field): AuthorizerField) -> Json<Value> {
Json(json!({ "field extracted": field }))
}

async fn extract_all_fields(AuthorizerFields(fields): AuthorizerFields) -> Json<Value> {
Json(json!({ "authorizer fields": fields }))
}

#[tokio::main]
async fn main() -> Result<(), Error> {
// If you use API Gateway stages, the Rust Runtime will include the stage name
// as part of the path that your application receives.
// Setting the following environment variable, you can remove the stage from the path.
// This variable only applies to API Gateway stages,
// you can remove it if you don't use them.
// i.e with: `GET /test-stage/todo/id/123` without: `GET /todo/id/123`
set_var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH", "true");

// required to enable CloudWatch error logging by the runtime
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
// disable printing the name of the module in every log line.
.with_target(false)
// disabling time is handy because CloudWatch will add the ingestion time.
.without_time()
.init();

let app = Router::new()
.route("/extract-field", get(extract_field))
.route("/extract-all-fields", get(extract_all_fields));

run(app).await
}

0 comments on commit 93a0c3f

Please sign in to comment.