diff --git a/src/handlers/api.rs b/src/handlers/api.rs index a1519e7..b7df539 100644 --- a/src/handlers/api.rs +++ b/src/handlers/api.rs @@ -2,11 +2,11 @@ use std::sync::LazyLock; use axum::{ extract::Request, - http::{HeaderValue, Uri}, + http::{Method, Uri}, response::Response, }; use reqwest::{ - header::{AUTHORIZATION, CONTENT_TYPE, HOST}, + header::{CONTENT_TYPE, HOST}, Client, }; @@ -17,9 +17,14 @@ use crate::{ pub static HTTP: LazyLock = LazyLock::new(Client::default); -pub const DISCORD_HOST: HeaderValue = HeaderValue::from_static("discord.com"); - pub async fn api_handler(request: Request) -> Response { + if request.method() != Method::GET && request.method() != Method::DELETE { + return Response::builder() + .status(405) + .body("Method Not Allowed".into()) + .unwrap(); + } + let (mut head, _) = request.into_parts(); let uri = head.uri.into_parts(); @@ -30,24 +35,51 @@ pub async fn api_handler(request: Request) -> Response { .unwrap(); }; + let host = head + .headers + .get("x-host") + .and_then(|host| host.to_str().ok()) + .unwrap_or("discord.com") + .to_lowercase(); + + let authorization_header = head + .headers + .get("x-authorization-name") + .map(|x| x.to_str().unwrap()) + .unwrap_or("authorization") + .to_lowercase(); + + let authorization = head + .headers + .get(&authorization_header) + .and_then(|x| x.to_str().ok()); + let cache_key = create_cache_key( head.method.as_str().as_bytes(), path.as_str().as_bytes(), - head.headers.get(AUTHORIZATION).map(|x| x.as_bytes()), + host.as_bytes(), + authorization_header.as_bytes(), + authorization.map(|x| x.as_bytes()), ); + if head.method == Method::DELETE { + DB.delete(cache_key).await; + + return Response::builder().status(200).body("OK".into()).unwrap(); + } + if let Some(cache_response) = DB.get(cache_key).await { return cache_response.into(); } let url = Uri::builder() - .authority("discord.com") + .authority(host.as_str()) .scheme("https") .path_and_query(path) .build() .unwrap(); - head.headers.insert(HOST, DISCORD_HOST); + head.headers.insert(HOST, host.parse().unwrap()); let mut response = HTTP .get(url.to_string()) diff --git a/src/handlers/invalidate.rs b/src/handlers/invalidate.rs deleted file mode 100644 index ac7feba..0000000 --- a/src/handlers/invalidate.rs +++ /dev/null @@ -1,16 +0,0 @@ -use axum::Json; - -use crate::{ - db::DB, - hash::{create_cache_key, CacheKeyPayload}, -}; - -pub async fn invalidate_handler(body: Json) { - let key = create_cache_key( - body.method.as_bytes(), - body.url.as_bytes(), - body.authorization.as_ref().map(|header| header.as_bytes()), - ); - - DB.delete(key).await; -} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index ba6614a..faf9f7c 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,3 +1,2 @@ pub mod api; pub mod health; -pub mod invalidate; diff --git a/src/hash.rs b/src/hash.rs index c52295d..00f8717 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -5,15 +5,30 @@ use xxhash_rust::xxh3; pub struct CacheKeyPayload { pub method: String, pub url: String, + pub host: String, + pub authorization_header: String, pub authorization: Option, } -pub fn create_cache_key(method: &[u8], url: &[u8], authorization: Option<&[u8]>) -> i64 { - let mut buffer = - Vec::with_capacity(method.len() + url.len() + authorization.map_or(0, |x| x.len())); +pub fn create_cache_key( + method: &[u8], + url: &[u8], + host: &[u8], + authorization_header: &[u8], + authorization: Option<&[u8]>, +) -> i64 { + let mut buffer = Vec::with_capacity( + method.len() + + url.len() + + host.len() + + authorization_header.len() + + authorization.map_or(0, |x| x.len()), + ); buffer.extend_from_slice(method); buffer.extend_from_slice(url); + buffer.extend_from_slice(host); + buffer.extend_from_slice(authorization_header); if let Some(authorization) = authorization { buffer.extend_from_slice(authorization); diff --git a/src/main.rs b/src/main.rs index 24d0ed9..88b62b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,9 @@ pub mod db; mod handlers; pub mod hash; -use axum::{ - routing::{get, post}, - serve, Router, -}; +use axum::{routing::any, serve, Router}; use db::DB; -use handlers::{api::api_handler, health::health_handler, invalidate::invalidate_handler}; +use handlers::{api::api_handler, health::health_handler}; use tokio::net::TcpListener; use tracing::{info, level_filters::LevelFilter}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; @@ -31,9 +28,8 @@ async fn main() { DB.seed().await; let app = Router::new() - .route("/", get(health_handler)) - .route("/api/*path", get(api_handler)) - .route("/invalidate", post(invalidate_handler)); + .route("/", any(health_handler)) + .route("/api/*path", any(api_handler)); info!("listening on {}", BIND_ADDRESS);