Skip to content

Commit

Permalink
improve forwarded header parsing. (#896)
Browse files Browse the repository at this point in the history
* improve forwarded header parsing.

* trim key name before compare.
  • Loading branch information
fakeshadow authored Jan 19, 2024
1 parent 8ad12a7 commit a1bce40
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions http-rate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ impl RateLimit {
}

/// Rate limit [Request] based on it's [HeaderMap] state and given client [SocketAddr]
/// "x-real-ip", "x-forwarded-for" and "forwarded" are checked in order start from left to
/// determine client's socket address. Received [SocketAddr] will be used as fallback when
/// all headers are absent or can't provide valid client address.
/// "x-real-ip", "x-forwarded-for" and "forwarded" headers are checked in order start
/// from left to determine client's socket address. Received [SocketAddr] will be used
/// as fallback when all headers are absent or can't provide valid client address.
///
/// [Request]: http::Request
pub fn rate_limit(&self, headers: &HeaderMap, addr: &SocketAddr) -> Result<RateSnapshot, TooManyRequests> {
Expand Down Expand Up @@ -66,28 +66,22 @@ fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
}

fn maybe_forwarded(headers: &HeaderMap) -> Option<IpAddr> {
let mut res = None;

for mut val in headers
headers
.get_all(FORWARDED)
.iter()
.filter_map(|h| h.to_str().ok())
.flat_map(|val| val.split(';'))
.flat_map(|p| p.split(','))
.map(|val| val.trim().splitn(2, '='))
{
if let (Some(name), Some(val)) = (val.next(), val.next()) {
if name.eq_ignore_ascii_case("for") {
.find_map(|mut val| match (val.next(), val.next()) {
(Some(name), Some(val)) if name.trim().eq_ignore_ascii_case("for") => {
let val = val.trim();
match val.parse::<SocketAddr>() {
Ok(addr) => res = Some(addr.ip()),
Err(_) => res = val.parse::<IpAddr>().ok(),
}
val.parse::<IpAddr>()
.or_else(|_| val.parse::<SocketAddr>().map(|addr| addr.ip()))
.ok()
}
}
}

res
_ => None,
})
}

#[cfg(test)]
Expand Down Expand Up @@ -117,7 +111,7 @@ mod test {
let mut headers = HeaderMap::new();
headers.insert(
FORWARDED,
HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43"),
HeaderValue::from_static("for =192.0.2.60;proto=http;by=203.0.113.43"),
);
assert_eq!(maybe_forwarded(&headers).unwrap().to_string(), "192.0.2.60");
}
Expand All @@ -143,7 +137,7 @@ mod test {
assert!(lb.check().is_ok());

clock.advance(ms);
assert!(lb.check().is_err(), "{:?}", lb);
assert!(lb.check().is_err(), "{lb:?}");
}

#[test]
Expand All @@ -168,7 +162,7 @@ mod test {
assert!(lb.check_n(one).unwrap().is_ok());

clock.advance(ms);
assert!(lb.check_n(one).unwrap().is_err(), "{:?}", lb);
assert!(lb.check_n(one).unwrap().is_err(), "{lb:?}");
}

#[test]
Expand Down

0 comments on commit a1bce40

Please sign in to comment.