Skip to content

Commit

Permalink
Cors middleware extension (#275)
Browse files Browse the repository at this point in the history
* feat: Add AllowOrigin enum

* fix build_preflight_request

* feat: Add to handle origin method

* feat: Add origin validation

* test: Add test code

* refactor: Rename AllowOrigin to CorsOrigin

* refactor: cargo fmt + fix clippy

* Update tide-cors/src/middleware.rs

Co-Authored-By: Yoshua Wuyts <[email protected]>

* Update tide-cors/src/middleware.rs

Co-Authored-By: Yoshua Wuyts <[email protected]>
  • Loading branch information
k-nasa and yoshuawuyts committed Jul 22, 2019
1 parent b77a242 commit 9a9b10a
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 20 deletions.
4 changes: 2 additions & 2 deletions examples/cors.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#![feature(async_await)]

use http::header::HeaderValue;
use tide::middleware::CorsMiddleware;
use tide::middleware::{CorsMiddleware, CorsOrigin};

fn main() {
let mut app = tide::App::new();

app.middleware(
CorsMiddleware::new()
.allow_origin(HeaderValue::from_static("*"))
.allow_origin(CorsOrigin::from("*"))
.allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")),
);

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub mod middleware {
pub use tide_log::RequestLogger;

#[cfg(feature = "cors")]
pub use tide_cors::CorsMiddleware;
pub use tide_cors::{CorsMiddleware, CorsOrigin};

#[cfg(feature = "cookies")]
pub use tide_cookies::CookiesMiddleware;
Expand Down
6 changes: 3 additions & 3 deletions tide-cors/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
//! #![feature(async_await)]
//!
//! use http::header::HeaderValue;
//! use tide_cors::CorsMiddleware;
//! use tide::middleware::{CorsMiddleware, CorsOrigin};
//!
//! fn main() {
//! let mut app = tide::App::new();
//!
//! app.middleware(
//! CorsMiddleware::new()
//! .allow_origin(HeaderValue::from_static("*"))
//! .allow_origin(CorsOrigin::from("*"))
//! .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")),
//! );
//!
Expand Down Expand Up @@ -41,4 +41,4 @@

mod middleware;

pub use self::middleware::CorsMiddleware;
pub use self::middleware::{CorsMiddleware, CorsOrigin};
166 changes: 152 additions & 14 deletions tide-cors/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use tide_core::{
///
/// ```rust
///use http::header::HeaderValue;
///use tide_cors::CorsMiddleware;
///use tide::middleware::{CorsOrigin, CorsMiddleware};
///
///CorsMiddleware::new()
/// .allow_origin(HeaderValue::from_static("*"))
/// .allow_origin(CorsOrigin::from("*"))
/// .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS"))
/// .allow_credentials(false);
/// ```
Expand All @@ -28,11 +28,53 @@ pub struct CorsMiddleware {
allow_credentials: Option<HeaderValue>,
allow_headers: HeaderValue,
allow_methods: HeaderValue,
allow_origin: HeaderValue,
allow_origin: CorsOrigin,
expose_headers: Option<HeaderValue>,
max_age: HeaderValue,
}

/// allow_origin enum
#[derive(Clone, Debug, Hash, PartialEq)]
pub enum CorsOrigin {
/// Wildcard. Accept all origin requests
Any,
/// Set a single allow_origin target
Exact(String),
/// Set multiple allow_origin targets
List(Vec<String>),
}

impl From<String> for CorsOrigin {
fn from(s: String) -> Self {
if s == "*" {
return CorsOrigin::Any;
}
CorsOrigin::Exact(s)
}
}

impl From<&str> for CorsOrigin {
fn from(s: &str) -> Self {
CorsOrigin::from(s.to_string())
}
}

impl From<Vec<String>> for CorsOrigin {
fn from(list: Vec<String>) -> Self {
if list.len() == 1 {
return Self::from(list[0].clone());
}

CorsOrigin::List(list)
}
}

impl From<Vec<&str>> for CorsOrigin {
fn from(list: Vec<&str>) -> Self {
CorsOrigin::from(list.iter().map(|s| s.to_string()).collect::<Vec<String>>())
}
}

pub const DEFAULT_MAX_AGE: &str = "86400";
pub const DEFAULT_METHODS: &str = "GET, POST, OPTIONS";
pub const WILDCARD: &str = "*";
Expand All @@ -44,7 +86,7 @@ impl CorsMiddleware {
allow_credentials: None,
allow_headers: HeaderValue::from_static(WILDCARD),
allow_methods: HeaderValue::from_static(DEFAULT_METHODS),
allow_origin: HeaderValue::from_static(WILDCARD),
allow_origin: CorsOrigin::Any,
expose_headers: None,
max_age: HeaderValue::from_static(DEFAULT_MAX_AGE),
}
Expand Down Expand Up @@ -78,7 +120,7 @@ impl CorsMiddleware {
}

/// Set allow_origin and return new CorsMiddleware
pub fn allow_origin<T: Into<HeaderValue>>(mut self, origin: T) -> Self {
pub fn allow_origin<T: Into<CorsOrigin>>(mut self, origin: T) -> Self {
self.allow_origin = origin.into();
self
}
Expand All @@ -89,13 +131,10 @@ impl CorsMiddleware {
self
}

fn build_preflight_response(&self) -> http::response::Response<Body> {
fn build_preflight_response(&self, origin: &HeaderValue) -> http::response::Response<Body> {
let mut response = http::Response::builder()
.status(StatusCode::OK)
.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.allow_origin.clone(),
)
.header::<_, HeaderValue>(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone())
.header(
header::ACCESS_CONTROL_ALLOW_METHODS,
self.allow_methods.clone(),
Expand All @@ -122,22 +161,65 @@ impl CorsMiddleware {

response
}

/// Look at origin of request and determine allow_origin
fn response_origin<T: Into<HeaderValue>>(&self, origin: T) -> Option<HeaderValue> {
let origin = origin.into();
if !self.is_valid_origin(origin.clone()) {
return None;
}

match self.allow_origin {
CorsOrigin::Any => Some(HeaderValue::from_static(WILDCARD)),
_ => Some(origin),
}
}

/// Determine if origin is appropriate
fn is_valid_origin<T: Into<HeaderValue>>(&self, origin: T) -> bool {
let origin = match origin.into().to_str() {
Ok(s) => s.to_string(),
Err(_) => return false,
};

match &self.allow_origin {
CorsOrigin::Any => true,
CorsOrigin::Exact(s) => s == &origin,
CorsOrigin::List(list) => list.contains(&origin),
}
}
}

impl<State: Send + Sync + 'static> Middleware<State> for CorsMiddleware {
fn handle<'a>(&'a self, cx: Context<State>, next: Next<'a, State>) -> BoxFuture<'a, Response> {
Box::pin(async move {
let origin = if let Some(origin) = cx.request().headers().get(header::ORIGIN) {
origin.clone()
} else {
return http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap();
};

if !self.is_valid_origin(&origin) {
return http::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap();
}

// Return results immediately upon preflight request
if cx.method() == Method::OPTIONS {
return self.build_preflight_response();
return self.build_preflight_response(&origin);
}

let mut response = next.run(cx).await;
let headers = response.headers_mut();

headers.append(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.allow_origin.clone(),
self.response_origin(origin).unwrap(),
);

if let Some(allow_credentials) = self.allow_credentials.clone() {
Expand Down Expand Up @@ -187,7 +269,7 @@ mod test {
let mut app = app();
app.middleware(
CorsMiddleware::new()
.allow_origin(HeaderValue::from_static(ALLOW_ORIGIN))
.allow_origin(CorsOrigin::from(ALLOW_ORIGIN))
.allow_methods(HeaderValue::from_static(ALLOW_METHODS))
.expose_headers(HeaderValue::from_static(EXPOSE_HEADER))
.allow_credentials(true),
Expand Down Expand Up @@ -250,7 +332,7 @@ mod test {
let mut app = app();
app.middleware(
CorsMiddleware::new()
.allow_origin(HeaderValue::from_static(ALLOW_ORIGIN))
.allow_origin(CorsOrigin::from(ALLOW_ORIGIN))
.allow_credentials(false)
.allow_methods(HeaderValue::from_static(ALLOW_METHODS))
.expose_headers(HeaderValue::from_static(EXPOSE_HEADER)),
Expand Down Expand Up @@ -282,4 +364,60 @@ mod test {
"true"
);
}
#[test]
fn set_allow_origin_list() {
let mut app = app();
let origins = vec![ALLOW_ORIGIN, "foo.com", "bar.com"];
app.middleware(CorsMiddleware::new().allow_origin(origins.clone()));
let mut server = make_server(app.into_http_service()).unwrap();

for origin in origins {
let request = http::Request::get(ENDPOINT)
.header(http::header::ORIGIN, origin)
.method(http::method::Method::GET)
.body(Body::empty())
.unwrap();

let res = server.simulate(request).unwrap();

assert_eq!(res.status(), 200);
assert_eq!(
res.headers().get("access-control-allow-origin").unwrap(),
origin
);
}
}

#[test]
fn not_set_origin_header() {
let mut app = app();
app.middleware(CorsMiddleware::new());

let request = http::Request::get(ENDPOINT)
.method(http::method::Method::GET)
.body(Body::empty())
.unwrap();

let mut server = make_server(app.into_http_service()).unwrap();
let res = server.simulate(request).unwrap();

assert_eq!(res.status(), 400);
}

#[test]
fn unauthorized_origin() {
let mut app = app();
app.middleware(CorsMiddleware::new().allow_origin(ALLOW_ORIGIN));

let request = http::Request::get(ENDPOINT)
.header(http::header::ORIGIN, "unauthorize-origin.net")
.method(http::method::Method::GET)
.body(Body::empty())
.unwrap();

let mut server = make_server(app.into_http_service()).unwrap();
let res = server.simulate(request).unwrap();

assert_eq!(res.status(), 401);
}
}

0 comments on commit 9a9b10a

Please sign in to comment.