From 9a9b10a6fa7ad3aebbebd7a2404ac4ddc4c17ab7 Mon Sep 17 00:00:00 2001 From: nasa Date: Mon, 22 Jul 2019 18:54:05 +0900 Subject: [PATCH] Cors middleware extension (#275) * feat: Add AllowOrigin enum * fix build_preflight_request * feat: Add to handle origin method * feat: Add origin validation * test: Add test code * refactor: Rename AllowOrigin to CorsOrigin * refactor: cargo fmt + fix clippy * Update tide-cors/src/middleware.rs Co-Authored-By: Yoshua Wuyts * Update tide-cors/src/middleware.rs Co-Authored-By: Yoshua Wuyts --- examples/cors.rs | 4 +- src/lib.rs | 2 +- tide-cors/src/lib.rs | 6 +- tide-cors/src/middleware.rs | 166 +++++++++++++++++++++++++++++++++--- 4 files changed, 158 insertions(+), 20 deletions(-) diff --git a/examples/cors.rs b/examples/cors.rs index f32afd271..a62fc0450 100644 --- a/examples/cors.rs +++ b/examples/cors.rs @@ -1,14 +1,14 @@ #![feature(async_await)] use http::header::HeaderValue; -use tide::middleware::CorsMiddleware; +use tide::middleware::{CorsMiddleware, CorsOrigin}; fn main() { let mut app = tide::App::new(); app.middleware( CorsMiddleware::new() - .allow_origin(HeaderValue::from_static("*")) + .allow_origin(CorsOrigin::from("*")) .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")), ); diff --git a/src/lib.rs b/src/lib.rs index 894515225..713c99778 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,7 @@ pub mod middleware { pub use tide_log::RequestLogger; #[cfg(feature = "cors")] - pub use tide_cors::CorsMiddleware; + pub use tide_cors::{CorsMiddleware, CorsOrigin}; #[cfg(feature = "cookies")] pub use tide_cookies::CookiesMiddleware; diff --git a/tide-cors/src/lib.rs b/tide-cors/src/lib.rs index b7eaf44a8..17356897a 100644 --- a/tide-cors/src/lib.rs +++ b/tide-cors/src/lib.rs @@ -6,14 +6,14 @@ //! #![feature(async_await)] //! //! use http::header::HeaderValue; -//! use tide_cors::CorsMiddleware; +//! use tide::middleware::{CorsMiddleware, CorsOrigin}; //! //! fn main() { //! let mut app = tide::App::new(); //! //! app.middleware( //! CorsMiddleware::new() -//! .allow_origin(HeaderValue::from_static("*")) +//! .allow_origin(CorsOrigin::from("*")) //! .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")), //! ); //! @@ -41,4 +41,4 @@ mod middleware; -pub use self::middleware::CorsMiddleware; +pub use self::middleware::{CorsMiddleware, CorsOrigin}; diff --git a/tide-cors/src/middleware.rs b/tide-cors/src/middleware.rs index 7027034f0..2b979564e 100644 --- a/tide-cors/src/middleware.rs +++ b/tide-cors/src/middleware.rs @@ -16,10 +16,10 @@ use tide_core::{ /// /// ```rust ///use http::header::HeaderValue; -///use tide_cors::CorsMiddleware; +///use tide::middleware::{CorsOrigin, CorsMiddleware}; /// ///CorsMiddleware::new() -/// .allow_origin(HeaderValue::from_static("*")) +/// .allow_origin(CorsOrigin::from("*")) /// .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")) /// .allow_credentials(false); /// ``` @@ -28,11 +28,53 @@ pub struct CorsMiddleware { allow_credentials: Option, allow_headers: HeaderValue, allow_methods: HeaderValue, - allow_origin: HeaderValue, + allow_origin: CorsOrigin, expose_headers: Option, max_age: HeaderValue, } +/// allow_origin enum +#[derive(Clone, Debug, Hash, PartialEq)] +pub enum CorsOrigin { + /// Wildcard. Accept all origin requests + Any, + /// Set a single allow_origin target + Exact(String), + /// Set multiple allow_origin targets + List(Vec), +} + +impl From for CorsOrigin { + fn from(s: String) -> Self { + if s == "*" { + return CorsOrigin::Any; + } + CorsOrigin::Exact(s) + } +} + +impl From<&str> for CorsOrigin { + fn from(s: &str) -> Self { + CorsOrigin::from(s.to_string()) + } +} + +impl From> for CorsOrigin { + fn from(list: Vec) -> Self { + if list.len() == 1 { + return Self::from(list[0].clone()); + } + + CorsOrigin::List(list) + } +} + +impl From> for CorsOrigin { + fn from(list: Vec<&str>) -> Self { + CorsOrigin::from(list.iter().map(|s| s.to_string()).collect::>()) + } +} + pub const DEFAULT_MAX_AGE: &str = "86400"; pub const DEFAULT_METHODS: &str = "GET, POST, OPTIONS"; pub const WILDCARD: &str = "*"; @@ -44,7 +86,7 @@ impl CorsMiddleware { allow_credentials: None, allow_headers: HeaderValue::from_static(WILDCARD), allow_methods: HeaderValue::from_static(DEFAULT_METHODS), - allow_origin: HeaderValue::from_static(WILDCARD), + allow_origin: CorsOrigin::Any, expose_headers: None, max_age: HeaderValue::from_static(DEFAULT_MAX_AGE), } @@ -78,7 +120,7 @@ impl CorsMiddleware { } /// Set allow_origin and return new CorsMiddleware - pub fn allow_origin>(mut self, origin: T) -> Self { + pub fn allow_origin>(mut self, origin: T) -> Self { self.allow_origin = origin.into(); self } @@ -89,13 +131,10 @@ impl CorsMiddleware { self } - fn build_preflight_response(&self) -> http::response::Response { + fn build_preflight_response(&self, origin: &HeaderValue) -> http::response::Response { let mut response = http::Response::builder() .status(StatusCode::OK) - .header( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.allow_origin.clone(), - ) + .header::<_, HeaderValue>(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()) .header( header::ACCESS_CONTROL_ALLOW_METHODS, self.allow_methods.clone(), @@ -122,14 +161,57 @@ impl CorsMiddleware { response } + + /// Look at origin of request and determine allow_origin + fn response_origin>(&self, origin: T) -> Option { + let origin = origin.into(); + if !self.is_valid_origin(origin.clone()) { + return None; + } + + match self.allow_origin { + CorsOrigin::Any => Some(HeaderValue::from_static(WILDCARD)), + _ => Some(origin), + } + } + + /// Determine if origin is appropriate + fn is_valid_origin>(&self, origin: T) -> bool { + let origin = match origin.into().to_str() { + Ok(s) => s.to_string(), + Err(_) => return false, + }; + + match &self.allow_origin { + CorsOrigin::Any => true, + CorsOrigin::Exact(s) => s == &origin, + CorsOrigin::List(list) => list.contains(&origin), + } + } } impl Middleware for CorsMiddleware { fn handle<'a>(&'a self, cx: Context, next: Next<'a, State>) -> BoxFuture<'a, Response> { Box::pin(async move { + let origin = if let Some(origin) = cx.request().headers().get(header::ORIGIN) { + origin.clone() + } else { + return http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(); + }; + + if !self.is_valid_origin(&origin) { + return http::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap(); + } + // Return results immediately upon preflight request if cx.method() == Method::OPTIONS { - return self.build_preflight_response(); + return self.build_preflight_response(&origin); } let mut response = next.run(cx).await; @@ -137,7 +219,7 @@ impl Middleware for CorsMiddleware { headers.append( header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.allow_origin.clone(), + self.response_origin(origin).unwrap(), ); if let Some(allow_credentials) = self.allow_credentials.clone() { @@ -187,7 +269,7 @@ mod test { let mut app = app(); app.middleware( CorsMiddleware::new() - .allow_origin(HeaderValue::from_static(ALLOW_ORIGIN)) + .allow_origin(CorsOrigin::from(ALLOW_ORIGIN)) .allow_methods(HeaderValue::from_static(ALLOW_METHODS)) .expose_headers(HeaderValue::from_static(EXPOSE_HEADER)) .allow_credentials(true), @@ -250,7 +332,7 @@ mod test { let mut app = app(); app.middleware( CorsMiddleware::new() - .allow_origin(HeaderValue::from_static(ALLOW_ORIGIN)) + .allow_origin(CorsOrigin::from(ALLOW_ORIGIN)) .allow_credentials(false) .allow_methods(HeaderValue::from_static(ALLOW_METHODS)) .expose_headers(HeaderValue::from_static(EXPOSE_HEADER)), @@ -282,4 +364,60 @@ mod test { "true" ); } + #[test] + fn set_allow_origin_list() { + let mut app = app(); + let origins = vec![ALLOW_ORIGIN, "foo.com", "bar.com"]; + app.middleware(CorsMiddleware::new().allow_origin(origins.clone())); + let mut server = make_server(app.into_http_service()).unwrap(); + + for origin in origins { + let request = http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, origin) + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!( + res.headers().get("access-control-allow-origin").unwrap(), + origin + ); + } + } + + #[test] + fn not_set_origin_header() { + let mut app = app(); + app.middleware(CorsMiddleware::new()); + + let request = http::Request::get(ENDPOINT) + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 400); + } + + #[test] + fn unauthorized_origin() { + let mut app = app(); + app.middleware(CorsMiddleware::new().allow_origin(ALLOW_ORIGIN)); + + let request = http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, "unauthorize-origin.net") + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 401); + } }