|
| 1 | +use core::str; |
1 | 2 | use std::borrow::{Borrow, Cow};
|
2 | 3 | use std::fmt;
|
3 | 4 |
|
| 5 | +use headers::{Header as HHeader, HeaderValue}; |
4 | 6 | use indexmap::IndexMap;
|
5 | 7 |
|
6 | 8 | use crate::uncased::{Uncased, UncasedStr};
|
@@ -798,10 +800,153 @@ impl From<&cookie::Cookie<'_>> for Header<'static> {
|
798 | 800 | }
|
799 | 801 | }
|
800 | 802 |
|
| 803 | +/// A destination for `HeaderValue`s that can be used to accumulate |
| 804 | +/// a single header value using from hyperium headers' decode protocol. |
| 805 | +#[derive(Default)] |
| 806 | +struct HeaderValueDestination { |
| 807 | + value: Option<HeaderValue>, |
| 808 | + count: usize, |
| 809 | +} |
| 810 | + |
| 811 | +impl <'r>HeaderValueDestination { |
| 812 | + fn into_value(self) -> HeaderValue { |
| 813 | + if let Some(value) = self.value { |
| 814 | + // TODO: if value.count > 1, then log that multiple header values are |
| 815 | + // generated by the typed header, but that the dropped. |
| 816 | + value |
| 817 | + } else { |
| 818 | + // Perhaps log that the typed header didn't create any values. |
| 819 | + // This won't happen in the current implementation (headers 0.4.0). |
| 820 | + HeaderValue::from_static("") |
| 821 | + } |
| 822 | + } |
| 823 | + |
| 824 | + fn into_header_string(self) -> Cow<'static, str> { |
| 825 | + let value = self.into_value(); |
| 826 | + // TODO: Optimize if we know this is a static reference. |
| 827 | + value.to_str().unwrap_or("").to_string().into() |
| 828 | + } |
| 829 | +} |
| 830 | + |
| 831 | +impl Extend<HeaderValue> for HeaderValueDestination { |
| 832 | + fn extend<T: IntoIterator<Item = HeaderValue>>(&mut self, iter: T) { |
| 833 | + for value in iter { |
| 834 | + self.count += 1; |
| 835 | + if self.value.is_none() { |
| 836 | + self.value = Some(value) |
| 837 | + } |
| 838 | + } |
| 839 | + } |
| 840 | +} |
| 841 | + |
| 842 | +macro_rules! import_typed_headers { |
| 843 | +($($name:ident),*) => ($( |
| 844 | + pub use headers::$name; |
| 845 | + |
| 846 | + impl ::std::convert::From<self::$name> for Header<'static> { |
| 847 | + fn from(header: self::$name) -> Self { |
| 848 | + let mut destination = HeaderValueDestination::default(); |
| 849 | + header.encode(&mut destination); |
| 850 | + let name = self::$name::name(); |
| 851 | + Header::new(name.as_str(), destination.into_header_string()) |
| 852 | + } |
| 853 | + } |
| 854 | +)*) |
| 855 | +} |
| 856 | + |
| 857 | +macro_rules! import_generic_typed_headers { |
| 858 | +($($name:ident<$bound:ident>),*) => ($( |
| 859 | + pub use headers::$name; |
| 860 | + |
| 861 | + impl <T1: 'static + $bound>::std::convert::From<self::$name<T1>> |
| 862 | + for Header<'static> { |
| 863 | + fn from(header: self::$name<T1>) -> Self { |
| 864 | + let mut destination = HeaderValueDestination::default(); |
| 865 | + header.encode(&mut destination); |
| 866 | + let name = self::$name::<T1>::name(); |
| 867 | + Header::new(name.as_str(), destination.into_header_string()) |
| 868 | + } |
| 869 | + } |
| 870 | +)*) |
| 871 | +} |
| 872 | + |
| 873 | +// The following headers from 'headers' 0.4 are not imported, since they are |
| 874 | +// provided by other Rocket features. |
| 875 | + |
| 876 | +// * ContentType, // Content-Type header, defined in RFC7231 |
| 877 | +// * Cookie, // Cookie header, defined in RFC6265 |
| 878 | +// * Host, // The Host header. |
| 879 | +// * Location, // Location header, defined in RFC7231 |
| 880 | +// * SetCookie, // Set-Cookie header, defined RFC6265 |
| 881 | + |
| 882 | +import_typed_headers! { |
| 883 | + AcceptRanges, // Accept-Ranges header, defined in RFC7233 |
| 884 | + AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS |
| 885 | + AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS |
| 886 | + AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS |
| 887 | + AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS |
| 888 | + AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS |
| 889 | + AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS |
| 890 | + AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS |
| 891 | + AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS |
| 892 | + Age, // Age header, defined in RFC7234 |
| 893 | + Allow, // Allow header, defined in RFC7231 |
| 894 | + CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246 |
| 895 | + Connection, // Connection header, defined in RFC7230 |
| 896 | + ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266. |
| 897 | + ContentEncoding, // Content-Encoding header, defined in RFC7231 |
| 898 | + ContentLength, // Content-Length header, defined in RFC7230 |
| 899 | + ContentLocation, // Content-Location header, defined in RFC7231 |
| 900 | + ContentRange, // Content-Range, described in RFC7233 |
| 901 | + Date, // Date header, defined in RFC7231 |
| 902 | + ETag, // ETag header, defined in RFC7232 |
| 903 | + Expect, // The Expect header. |
| 904 | + Expires, // Expires header, defined in RFC7234 |
| 905 | + IfMatch, // If-Match header, defined in RFC7232 |
| 906 | + IfModifiedSince, // If-Modified-Since header, defined in RFC7232 |
| 907 | + IfNoneMatch, // If-None-Match header, defined in RFC7232 |
| 908 | + IfRange, // If-Range header, defined in RFC7233 |
| 909 | + IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232 |
| 910 | + LastModified, // Last-Modified header, defined in RFC7232 |
| 911 | + Origin, // The Origin header. |
| 912 | + Pragma, // The Pragma header defined by HTTP/1.0. |
| 913 | + Range, // Range header, defined in RFC7233 |
| 914 | + Referer, // Referer header, defined in RFC7231 |
| 915 | + ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy |
| 916 | + RetryAfter, // The Retry-After header. |
| 917 | + SecWebsocketAccept, // The Sec-Websocket-Accept header. |
| 918 | + SecWebsocketKey, // The Sec-Websocket-Key header. |
| 919 | + SecWebsocketVersion, // The Sec-Websocket-Version header. |
| 920 | + Server, // Server header, defined in RFC7231 |
| 921 | + StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797 |
| 922 | + Te, // TE header, defined in RFC7230 |
| 923 | + TransferEncoding, // Transfer-Encoding header, defined in RFC7230 |
| 924 | + Upgrade, // Upgrade header, defined in RFC7230 |
| 925 | + UserAgent, // User-Agent header, defined in RFC7231 |
| 926 | + Vary // Vary header, defined in RFC7231 |
| 927 | +} |
| 928 | + |
| 929 | +import_generic_typed_headers! { |
| 930 | + Authorization<Credentials>, // Authorization header, defined in RFC7235 |
| 931 | + ProxyAuthorization<Credentials> // Proxy-Authorization header, defined in RFC7235 |
| 932 | +} |
| 933 | + |
| 934 | +pub use headers::authorization::Credentials; |
| 935 | + |
801 | 936 | #[cfg(test)]
|
802 | 937 | mod tests {
|
| 938 | + use std::time::SystemTime; |
| 939 | + |
803 | 940 | use super::HeaderMap;
|
804 | 941 |
|
| 942 | + #[test] |
| 943 | + fn add_typed_header() { |
| 944 | + use super::LastModified; |
| 945 | + let mut map = HeaderMap::new(); |
| 946 | + map.add(LastModified::from(SystemTime::now())); |
| 947 | + assert!(map.get_one("last-modified").unwrap().contains("GMT")); |
| 948 | + } |
| 949 | + |
805 | 950 | #[test]
|
806 | 951 | fn case_insensitive_add_get() {
|
807 | 952 | let mut map = HeaderMap::new();
|
|
0 commit comments