Skip to content

Commit 93a0c3f

Browse files
committed
Add example on how to use Api Gateway authorizers in Axum.
This also shows how to work with the RequestExt trait and the RequestContext object. Signed-off-by: David Calavera <[email protected]>
1 parent 7a3ab97 commit 93a0c3f

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[package]
2+
name = "http-axum-apigw-authorizer"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
axum = "0.7"
8+
lambda_http = { path = "../../lambda-http" }
9+
lambda_runtime = { path = "../../lambda-runtime" }
10+
serde = "1.0.196"
11+
serde_json = "1.0"
12+
tokio = { version = "1", features = ["macros"] }
13+
tracing = { version = "0.1", features = ["log"] }
14+
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Axum example that integrates with Api Gateway authorizers
2+
3+
This example shows how to extract information from the Api Gateway Request Authorizer in an Axum handler.
4+
5+
## Build & Deploy
6+
7+
1. Install [cargo-lambda](https://github.com/cargo-lambda/cargo-lambda#installation)
8+
2. Build the function with `cargo lambda build --release`
9+
3. Deploy the function to AWS Lambda with `cargo lambda deploy --iam-role YOUR_ROLE`
10+
11+
## Build for ARM 64
12+
13+
Build the function with `cargo lambda build --release --arm64`
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use axum::{
2+
async_trait,
3+
extract::{FromRequest, Request},
4+
http::StatusCode,
5+
response::Json,
6+
routing::get,
7+
Router,
8+
};
9+
use lambda_http::{run, Error, RequestExt};
10+
use serde_json::{json, Value};
11+
use std::{collections::HashMap, env::set_var};
12+
13+
struct AuthorizerField(String);
14+
struct AuthorizerFields(HashMap<String, serde_json::Value>);
15+
16+
#[async_trait]
17+
impl<S> FromRequest<S> for AuthorizerField
18+
where
19+
S: Send + Sync,
20+
{
21+
type Rejection = (StatusCode, &'static str);
22+
23+
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
24+
req.request_context_ref()
25+
.and_then(|r| r.authorizer())
26+
.and_then(|a| a.fields.get("field_name"))
27+
.and_then(|f| f.as_str())
28+
.map(|v| Self(v.to_string()))
29+
.ok_or_else(|| (StatusCode::BAD_REQUEST, "`field_name` authorizer field is missing"))
30+
}
31+
}
32+
33+
#[async_trait]
34+
impl<S> FromRequest<S> for AuthorizerFields
35+
where
36+
S: Send + Sync,
37+
{
38+
type Rejection = (StatusCode, &'static str);
39+
40+
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
41+
req.request_context_ref()
42+
.and_then(|r| r.authorizer())
43+
.map(|a| Self(a.fields.clone()))
44+
.ok_or_else(|| (StatusCode::BAD_REQUEST, "authorizer is missing"))
45+
}
46+
}
47+
48+
async fn extract_field(AuthorizerField(field): AuthorizerField) -> Json<Value> {
49+
Json(json!({ "field extracted": field }))
50+
}
51+
52+
async fn extract_all_fields(AuthorizerFields(fields): AuthorizerFields) -> Json<Value> {
53+
Json(json!({ "authorizer fields": fields }))
54+
}
55+
56+
#[tokio::main]
57+
async fn main() -> Result<(), Error> {
58+
// If you use API Gateway stages, the Rust Runtime will include the stage name
59+
// as part of the path that your application receives.
60+
// Setting the following environment variable, you can remove the stage from the path.
61+
// This variable only applies to API Gateway stages,
62+
// you can remove it if you don't use them.
63+
// i.e with: `GET /test-stage/todo/id/123` without: `GET /todo/id/123`
64+
set_var("AWS_LAMBDA_HTTP_IGNORE_STAGE_IN_PATH", "true");
65+
66+
// required to enable CloudWatch error logging by the runtime
67+
tracing_subscriber::fmt()
68+
.with_max_level(tracing::Level::INFO)
69+
// disable printing the name of the module in every log line.
70+
.with_target(false)
71+
// disabling time is handy because CloudWatch will add the ingestion time.
72+
.without_time()
73+
.init();
74+
75+
let app = Router::new()
76+
.route("/extract-field", get(extract_field))
77+
.route("/extract-all-fields", get(extract_all_fields));
78+
79+
run(app).await
80+
}

0 commit comments

Comments
 (0)