Skip to content

Commit

Permalink
feat: add more host, authorization header support
Browse files Browse the repository at this point in the history
  • Loading branch information
kane50613 committed Dec 8, 2024
1 parent e2936f1 commit e49e580
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 35 deletions.
46 changes: 39 additions & 7 deletions src/handlers/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::sync::LazyLock;

use axum::{
extract::Request,
http::{HeaderValue, Uri},
http::{Method, Uri},
response::Response,
};
use reqwest::{
header::{AUTHORIZATION, CONTENT_TYPE, HOST},
header::{CONTENT_TYPE, HOST},
Client,
};

Expand All @@ -17,9 +17,14 @@ use crate::{

pub static HTTP: LazyLock<Client> = LazyLock::new(Client::default);

pub const DISCORD_HOST: HeaderValue = HeaderValue::from_static("discord.com");

pub async fn api_handler(request: Request) -> Response {
if request.method() != Method::GET && request.method() != Method::DELETE {
return Response::builder()
.status(405)
.body("Method Not Allowed".into())
.unwrap();
}

let (mut head, _) = request.into_parts();
let uri = head.uri.into_parts();

Expand All @@ -30,24 +35,51 @@ pub async fn api_handler(request: Request) -> Response {
.unwrap();
};

let host = head
.headers
.get("x-host")
.and_then(|host| host.to_str().ok())
.unwrap_or("discord.com")
.to_lowercase();

let authorization_header = head
.headers
.get("x-authorization-name")
.map(|x| x.to_str().unwrap())
.unwrap_or("authorization")
.to_lowercase();

let authorization = head
.headers
.get(&authorization_header)
.and_then(|x| x.to_str().ok());

let cache_key = create_cache_key(
head.method.as_str().as_bytes(),
path.as_str().as_bytes(),
head.headers.get(AUTHORIZATION).map(|x| x.as_bytes()),
host.as_bytes(),
authorization_header.as_bytes(),
authorization.map(|x| x.as_bytes()),
);

if head.method == Method::DELETE {
DB.delete(cache_key).await;

return Response::builder().status(200).body("OK".into()).unwrap();
}

if let Some(cache_response) = DB.get(cache_key).await {
return cache_response.into();
}

let url = Uri::builder()
.authority("discord.com")
.authority(host.as_str())
.scheme("https")
.path_and_query(path)
.build()
.unwrap();

head.headers.insert(HOST, DISCORD_HOST);
head.headers.insert(HOST, host.parse().unwrap());

let mut response = HTTP
.get(url.to_string())
Expand Down
16 changes: 0 additions & 16 deletions src/handlers/invalidate.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
pub mod api;
pub mod health;
pub mod invalidate;
21 changes: 18 additions & 3 deletions src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,30 @@ use xxhash_rust::xxh3;
pub struct CacheKeyPayload {
pub method: String,
pub url: String,
pub host: String,
pub authorization_header: String,
pub authorization: Option<String>,
}

pub fn create_cache_key(method: &[u8], url: &[u8], authorization: Option<&[u8]>) -> i64 {
let mut buffer =
Vec::with_capacity(method.len() + url.len() + authorization.map_or(0, |x| x.len()));
pub fn create_cache_key(
method: &[u8],
url: &[u8],
host: &[u8],
authorization_header: &[u8],
authorization: Option<&[u8]>,
) -> i64 {
let mut buffer = Vec::with_capacity(
method.len()
+ url.len()
+ host.len()
+ authorization_header.len()
+ authorization.map_or(0, |x| x.len()),
);

buffer.extend_from_slice(method);
buffer.extend_from_slice(url);
buffer.extend_from_slice(host);
buffer.extend_from_slice(authorization_header);

if let Some(authorization) = authorization {
buffer.extend_from_slice(authorization);
Expand Down
12 changes: 4 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ pub mod db;
mod handlers;
pub mod hash;

use axum::{
routing::{get, post},
serve, Router,
};
use axum::{routing::any, serve, Router};
use db::DB;
use handlers::{api::api_handler, health::health_handler, invalidate::invalidate_handler};
use handlers::{api::api_handler, health::health_handler};
use tokio::net::TcpListener;
use tracing::{info, level_filters::LevelFilter};
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt};
Expand All @@ -31,9 +28,8 @@ async fn main() {
DB.seed().await;

let app = Router::new()
.route("/", get(health_handler))
.route("/api/*path", get(api_handler))
.route("/invalidate", post(invalidate_handler));
.route("/", any(health_handler))
.route("/api/*path", any(api_handler));

info!("listening on {}", BIND_ADDRESS);

Expand Down

0 comments on commit e49e580

Please sign in to comment.