From 0bd425ff11c2618fcde90e610cdb6c59d64b75fd Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Tue, 11 Feb 2025 15:37:55 -0700 Subject: [PATCH] Implement caching for JWKSet --- Cargo.lock | 91 ++++++- edb/server/http.py | 31 ++- edb/server/protocol/auth_ext/base.py | 7 +- edb/server/tenant.py | 2 +- pyproject.toml | 1 - rust/http/Cargo.toml | 4 + rust/http/src/cache.rs | 340 +++++++++++++++++++++++++++ rust/http/src/lib.rs | 1 + rust/http/src/python.rs | 185 ++++++++++----- tests/test_http_ext_auth.py | 4 +- 10 files changed, 595 insertions(+), 71 deletions(-) create mode 100644 rust/http/src/cache.rs diff --git a/Cargo.lock b/Cargo.lock index 5bb774f79513..04c612812eeb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,7 +450,7 @@ dependencies = [ "futures", "genetic_algorithm", "itertools 0.13.0", - "lru", + "lru 0.12.5", "pretty_assertions", "pyo3", "pyo3_util", @@ -635,6 +635,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "derive_more" version = "2.0.1" @@ -1140,7 +1149,7 @@ dependencies = [ "serde_derive", "serde_json", "sha2", - "thiserror 2.0.3", + "thiserror 2.0.11", "tracing", "uuid", "zeroize", @@ -1376,6 +1385,10 @@ version = "0.1.0" dependencies = [ "eventsource-stream", "futures", + "http 1.1.0", + "http-body-util", + "http-cache-semantics", + "lru 0.13.0", "pyo3", "pyo3_util", "reqwest", @@ -1419,6 +1432,28 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-cache-semantics" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92baf25cf0b8c9246baecf3a444546360a97b569168fdf92563ee6a47829920c" +dependencies = [ + "http 1.1.0", + "http-serde", + "serde", + "time", +] + +[[package]] +name = "http-serde" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" +dependencies = [ + "http 1.1.0", + "serde", +] + [[package]] name = "httparse" version = "1.9.5" @@ -1799,6 +1834,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "lru" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "227748d55f2f0ab4735d87fd623798cb6b664512fe979705f829c9f81c934465" +dependencies = [ + "hashbrown", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -2044,6 +2088,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -2337,6 +2387,12 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -3469,6 +3525,37 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.7.6" diff --git a/edb/server/http.py b/edb/server/http.py index ed6ef329dba6..07bb0f0056fc 100644 --- a/edb/server/http.py +++ b/edb/server/http.py @@ -180,6 +180,7 @@ async def request( headers: HeaderType = None, data: bytes | str | dict[str, str] | None = None, json: Any | None = None, + cache: bool = False, ) -> tuple[int, bytearray, dict[str, str]]: self._ensure_task() path = self._process_path(path) @@ -191,7 +192,9 @@ async def request( self._requests[id] = asyncio.Future() start_time = time.monotonic() try: - self._ensure_client()._request(id, path, method, data, headers_list) + self._ensure_client()._request( + id, path, method, data, headers_list, cache + ) resp = await self._requests[id] if self._stat_callback: status_code, body, headers = resp @@ -217,9 +220,15 @@ async def request( finally: del self._requests[id] - async def get(self, path: str, *, headers: HeaderType = None) -> Response: + async def get( + self, + path: str, + *, + headers: HeaderType = None, + cache: bool = False, + ) -> Response: result = await self.request( - method="GET", path=path, data=None, headers=headers + method="GET", path=path, data=None, headers=headers, cache=cache ) return Response.from_tuple(result) @@ -387,12 +396,24 @@ def _process_path(self, path): return path async def request( - self, *, method, path, headers=None, data=None, json=None + self, + *, + method, + path, + headers=None, + data=None, + json=None, + cache=False, ): path = self._process_path(path) headers = self._process_headers(headers) return await self.http_client.request( - method=method, path=path, headers=headers, data=data, json=json + method=method, + path=path, + headers=headers, + data=data, + json=json, + cache=cache, ) async def stream_sse( diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index d43a492e4a61..09834ae440ff 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -151,7 +151,7 @@ async def fetch_user_info( async with self.http_factory( base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" ) as client: - r = await client.get(jwks_uri.path) + r = await client.get(jwks_uri.path, cache=True) # Load the token as a JWT object and verify it directly try: @@ -178,6 +178,9 @@ async def fetch_user_info( async def _get_oidc_config(self) -> data.OpenIDConfig: client = self.http_factory(base_url=self.issuer_url) - response = await client.get('/.well-known/openid-configuration') + response = await client.get( + '/.well-known/openid-configuration', + cache=True + ) config = response.json() return data.OpenIDConfig(**config) diff --git a/edb/server/tenant.py b/edb/server/tenant.py index bbdb0939e695..b53db0cbad44 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -300,7 +300,7 @@ def get_http_client(self, *, originator: str) -> HttpClient: user_agent=f"EdgeDB {buildmeta.get_version_string(short=True)}", stat_callback=lambda stat: logger.debug( f"HTTP stat: {originator} {stat}" - ), + ) ) return self._http_client diff --git a/pyproject.toml b/pyproject.toml index 6d601b285b3f..6d8474af0023 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ 'psutil~=5.8', 'setproctitle~=1.2', - 'hishel==0.0.24', 'webauthn~=2.0.0', 'argon2-cffi~=23.1.0', 'aiosmtplib~=3.0', diff --git a/rust/http/Cargo.toml b/rust/http/Cargo.toml index 62e84e0e1a89..e0e369f439c9 100644 --- a/rust/http/Cargo.toml +++ b/rust/http/Cargo.toml @@ -19,6 +19,10 @@ tracing.workspace = true scopeguard = "1" eventsource-stream = "0.2.3" +http-cache-semantics = { version = "2", features = [] } +http = "1" +http-body-util = "0.1.2" +lru = "0.13" # We want to use rustls to avoid setenv issues w/ OpenSSL and the system certs. As long # as we don't call `openssl_probe::*init*env*()` functions (functions that call setenv diff --git a/rust/http/src/cache.rs b/rust/http/src/cache.rs new file mode 100644 index 000000000000..7a3192948496 --- /dev/null +++ b/rust/http/src/cache.rs @@ -0,0 +1,340 @@ +use http::{HeaderMap, Method, Request, Response, Uri}; +use http_cache_semantics::{AfterResponse, BeforeRequest, CacheOptions, CachePolicy}; +use lru::LruCache; +use std::{ + num::NonZero, + sync::{Arc, Mutex}, + time::{Duration, SystemTime}, +}; + +#[derive(Debug, Clone)] +pub enum CacheBefore { + Request(http::Request>), + Response(http::Response>), +} + +struct CacheItems { + items: LruCache)>, + byte_size: usize, + max_byte_size: usize, +} + +impl CacheItems { + fn new(capacity: NonZero, max_byte_size: usize) -> Self { + Self { + items: LruCache::new(capacity), + byte_size: 0, + max_byte_size, + } + } + + fn insert(&mut self, uri: Uri, policy: T, body: Vec) { + let body_len = body.len(); + if let Some((_, old_body)) = self.items.push(uri, (policy, body)) { + self.byte_size = self.byte_size.saturating_sub(old_body.1.len()); + } + self.byte_size = self.byte_size.saturating_add(body_len); + while self.byte_size > self.max_byte_size { + if let Some((_, old_body)) = self.items.pop_lru() { + self.byte_size = self.byte_size.saturating_sub(old_body.1.len()); + } else { + return; + } + } + } + + fn get_mut(&mut self, uri: &Uri) -> Option<&mut (T, Vec)> { + self.items.get_mut(uri) + } +} + +#[derive(Clone)] +pub struct Cache { + cache_options: CacheOptions, + cache: Arc>>, +} + +impl Cache { + pub fn new() -> Self { + Self { + cache_options: CacheOptions { + shared: false, + // Immutable objects should be cached for 24 hours + immutable_min_time_to_live: Duration::from_secs(86_400), + ..Default::default() + }, + cache: Arc::new(Mutex::new(CacheItems::new( + NonZero::new(100).unwrap(), + 1024 * 1024, + ))), + } + } + + #[cfg(test)] + pub fn get_cache_body(&self, url: &Uri) -> Option<(bool, Vec)> { + let mut cache = self.cache.lock().unwrap(); + let entry = cache.get_mut(url); + if let Some((policy, body)) = entry { + let state = policy.is_stale(SystemTime::now()); + return Some((state, body.clone())); + } + None + } + + pub fn before_request( + &self, + allow_cache: bool, + method: &Method, + url: &Uri, + headers: &HeaderMap, + body: Vec, + ) -> CacheBefore { + let mut req = Request::new(body); + *req.method_mut() = method.clone(); + *req.uri_mut() = url.clone(); + *req.headers_mut() = headers.clone(); + + // Only cache GET requests + if !allow_cache || method != Method::GET { + return CacheBefore::Request(req); + } + + let now = SystemTime::now(); + let mut cache = self.cache.lock().unwrap(); + if let Some((policy, body)) = cache.get_mut(url) { + match policy.before_request(&req, now) { + BeforeRequest::Fresh(parts) => { + // Fresh response from cache + CacheBefore::Response(Response::from_parts(parts, body.clone())) + } + BeforeRequest::Stale { request, .. } => { + *req.uri_mut() = request.uri; + *req.headers_mut() = request.headers; + *req.method_mut() = request.method; + CacheBefore::Request(req) + } + } + } else { + CacheBefore::Request(req) + } + } + + pub fn after_request( + &self, + allow_cache: bool, + method: Method, + uri: Uri, + headers: HeaderMap, + res: &mut http::Response>, + ) { + // Only cache GET requests + if !allow_cache || method != Method::GET { + return; + } + + let now = SystemTime::now(); + let mut cache = self.cache.lock().unwrap(); + let entry = cache.get_mut(&uri); + + let mut req = Request::new(()); + *req.method_mut() = method; + *req.uri_mut() = uri.clone(); + *req.headers_mut() = headers; + + let mut resp = Response::new(vec![]); + *resp.status_mut() = res.status(); + *resp.headers_mut() = res.headers().clone(); + + if let Some((policy, body)) = entry { + let parts = match policy.after_response(&req, &resp, now) { + AfterResponse::NotModified(new_policy, parts) => { + // Not modified, return the cached response + *policy = new_policy; + *resp.body_mut() = body.clone(); + parts + } + AfterResponse::Modified(new_policy, parts) => { + // Modified, update the cache + *policy = new_policy; + *body = res.body().clone(); + parts + } + }; + *resp.headers_mut() = parts.headers; + *resp.status_mut() = parts.status; + *resp.version_mut() = parts.version; + } else { + let policy = CachePolicy::new_options(&req, &resp, now, self.cache_options); + if policy.is_storable() { + cache.insert(uri, policy, res.body().clone()); + } + } + } +} + +#[cfg(test)] +mod tests { + use http::*; + use std::str::FromStr; + + use super::*; + + fn get_google() -> (Method, Uri, HeaderMap, Vec) { + let method = Method::GET; + let uri = Uri::from_str("https://www.google.com").unwrap(); + let headers = HeaderMap::new(); + let body = vec![]; + (method, uri, headers, body) + } + + fn cache_control(resp: &mut Response>, value: &str) { + resp.headers_mut().insert( + HeaderName::from_static("cache-control"), + HeaderValue::from_str(value).unwrap(), + ); + } + + fn etag(resp: &mut Response>, value: &str) { + resp.headers_mut().insert( + HeaderName::from_static("etag"), + HeaderValue::from_str(value).unwrap(), + ); + } + + fn response(status: StatusCode, body: &str) -> Response> { + let mut resp = Response::new(body.as_bytes().to_vec()); + *resp.status_mut() = status; + resp + } + + #[test] + fn test_cache_byte_size_eviction() { + let mut cache_items = CacheItems::<()>::new(NonZero::new(100).unwrap(), 1024 * 1024); + cache_items.insert( + Uri::from_str("https://www.google.com").unwrap(), + (), + vec![0; 1024 * 1024], + ); + assert_eq!(cache_items.byte_size, 1024 * 1024); + assert_eq!(cache_items.items.len(), 1); + cache_items.insert( + Uri::from_str("https://www.example.com").unwrap(), + (), + vec![0; 1], + ); + assert_eq!(cache_items.byte_size, 1); + assert_eq!(cache_items.items.len(), 1); + cache_items.insert( + Uri::from_str("https://www.google.com").unwrap(), + (), + vec![0; 1024 * 1024], + ); + assert_eq!(cache_items.byte_size, 1024 * 1024); + assert_eq!(cache_items.items.len(), 1); + } + + #[test] + fn test_cache_capacity_eviction() { + let mut cache_items = CacheItems::<()>::new(NonZero::new(100).unwrap(), 1024 * 1024); + for i in 0..120 { + cache_items.insert( + Uri::from_str(&format!("https://www.example.com/{}", i)).unwrap(), + (), + vec![0; 10], + ); + } + assert_eq!(cache_items.byte_size, 1000); + assert_eq!(cache_items.items.len(), 100); + } + + #[test] + fn test_cache() { + let cache = Cache::new(); + let (method, uri, headers, body) = get_google(); + let before = cache.before_request(true, &method, &uri, &headers, body); + assert!(matches!(before, CacheBefore::Request(_))); + + let mut resp = response(StatusCode::OK, ""); + cache_control(&mut resp, "max-age=3600"); + etag(&mut resp, "\"1234567890\""); + cache.after_request( + true, + method.clone(), + uri.clone(), + headers.clone(), + &mut resp, + ); + + let (method, uri, headers, body) = get_google(); + let after = cache.before_request(true, &method, &uri, &headers, body); + + let CacheBefore::Response(resp) = after else { + panic!("Expected a response {after:?}"); + }; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get("etag"), + Some(&HeaderValue::from_str("\"1234567890\"").unwrap()) + ); + assert_eq!( + resp.headers().get("cache-control"), + Some(&HeaderValue::from_str("max-age=3600").unwrap()) + ); + } + + #[test] + fn test_cache_not_modified() { + let cache = Cache::new(); + let (method, uri, headers, body) = get_google(); + let before = cache.before_request(true, &method, &uri, &headers, body); + assert!(matches!(before, CacheBefore::Request(_))); + + let mut resp = response(StatusCode::OK, "contents!"); + cache_control( + &mut resp, + "max-age=0, must-revalidate, stale-while-revalidate=86400", + ); + etag(&mut resp, "\"1234567890\""); + cache.after_request( + true, + method.clone(), + uri.clone(), + headers.clone(), + &mut resp, + ); + + let (state, body) = cache.get_cache_body(&uri).unwrap(); + assert_eq!(state, true); + assert_eq!(body, "contents!".as_bytes()); + + let (method, uri, headers, body) = get_google(); + let after = cache.before_request(true, &method, &uri, &headers, body); + let CacheBefore::Request(req) = after else { + panic!("Expected a request {after:?}"); + }; + assert_eq!(req.method(), &Method::GET); + assert_eq!(req.uri(), &uri); + assert_eq!( + req.headers().get("if-none-match"), + Some(&HeaderValue::from_str("\"1234567890\"").unwrap()) + ); + + let mut resp = response(StatusCode::NOT_MODIFIED, ""); + cache_control( + &mut resp, + "max-age=0, must-revalidate, stale-while-revalidate=86400", + ); + etag(&mut resp, "\"1234567890\""); + cache.after_request( + true, + method.clone(), + uri.clone(), + headers.clone(), + &mut resp, + ); + + let (state, body) = cache.get_cache_body(&uri).unwrap(); + assert_eq!(state, true); + assert_eq!(body, "contents!".as_bytes()); + } +} diff --git a/rust/http/src/lib.rs b/rust/http/src/lib.rs index 71429ec8ff93..0f1e17c95343 100644 --- a/rust/http/src/lib.rs +++ b/rust/http/src/lib.rs @@ -1,2 +1,3 @@ +mod cache; #[cfg(feature = "python_extension")] pub mod python; diff --git a/rust/http/src/python.rs b/rust/http/src/python.rs index 03832f2b00b9..6894edc8cc28 100644 --- a/rust/http/src/python.rs +++ b/rust/http/src/python.rs @@ -1,5 +1,7 @@ use eventsource_stream::Eventsource; use futures::{future::poll_fn, TryStreamExt}; +use http::{HeaderMap, HeaderName, HeaderValue, Uri}; +use http_body_util::BodyExt; use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray}; use pyo3_util::logging::{get_python_logger_level, initialize_logging_in_thread}; use reqwest::Method; @@ -10,6 +12,7 @@ use std::{ os::fd::IntoRawFd, pin::Pin, rc::Rc, + str::FromStr, sync::{Arc, Mutex}, thread, time::Duration, @@ -21,6 +24,8 @@ use tokio::{ }; use tracing::{error, info, trace}; +use crate::cache::{Cache, CacheBefore}; + pyo3::create_exception!(_http, InternalError, PyException); /// The backlog for SSE message @@ -60,7 +65,14 @@ enum PythonToRustMessage { /// Update the inflight limit UpdateLimit(usize), /// Perform a request - Request(PythonConnId, String, String, Vec, Vec<(String, String)>), + Request( + PythonConnId, + String, + String, + Vec, + Vec<(String, String)>, + bool, + ), /// Perform a request with SSE RequestSse(PythonConnId, String, String, Vec, Vec<(String, String)>), /// Close an SSE connection @@ -117,32 +129,87 @@ fn internal_error(message: &str) -> PyErr { InternalError::new_err(()) } +/// If this is likely a stream, returns the `Stream` variant. +/// Otherwise, returns the `Bytes` variant. +enum MaybeResponse { + Bytes(Vec), + Stream(reqwest::Body), +} + +impl MaybeResponse { + async fn try_into_bytes(self) -> Result, String> { + match self { + MaybeResponse::Bytes(bytes) => Ok(bytes), + MaybeResponse::Stream(body) => Ok(http_body_util::BodyExt::collect(body) + .await + .map_err(|e| format!("Failed to read response body: {e:?}"))? + .to_bytes() + .to_vec()), + } + } +} + async fn request( client: reqwest::Client, url: String, method: String, body: Vec, headers: Vec<(String, String)>, -) -> Result { + allow_cache: bool, + cache: Cache, +) -> Result, String> { + let headers = parse_headers(headers)?; let method = Method::from_bytes(method.as_bytes()).map_err(|e| format!("Invalid HTTP method: {e:?}"))?; + let uri = Uri::from_str(&url).map_err(|e| format!("Invalid URL: {e:?}"))?; - let mut req = client.request(method, url); - - for (key, value) in headers { - req = req.header(key, value); - } - - if !body.is_empty() { - req = req.body(body); - } + let req = match cache.before_request(allow_cache, &method, &uri, &headers, body) { + CacheBefore::Request(req) => req, + CacheBefore::Response(resp) => { + return Ok(resp.map(MaybeResponse::Bytes)); + } + }; - let resp = req - .send() + let resp = client + .execute( + req.try_into() + .map_err(|e| format!("Invalid request: {e:?}"))?, + ) .await .map_err(|e| format!("Request failed: {e:?}"))?; + let resp: http::Response<_> = resp.into(); + + let content_type = resp.headers().get("content-type"); + let is_event_stream = content_type + .and_then(|v| v.to_str().ok()) + .map(|s| s.starts_with("text/event-stream")) + .unwrap_or(false); + + let mut resp = if is_event_stream { + return Ok(resp.map(MaybeResponse::Stream)); + } else { + let (parts, body) = resp.into_parts(); + let bytes = http_body_util::BodyExt::collect(body) + .await + .map_err(|e| format!("Failed to read response body: {e:?}"))? + .to_bytes(); + http::Response::from_parts(parts, bytes.to_vec()) + }; + + cache.after_request(allow_cache, method, uri, headers, &mut resp); - Ok(resp) + Ok(resp.map(MaybeResponse::Bytes)) +} + +fn parse_headers(headers: Vec<(String, String)>) -> Result { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + header_map.insert( + HeaderName::from_str(&key).map_err(|e| format!("Invalid header name: {e:?}"))?, + HeaderValue::from_str(&value).map_err(|e| format!("Invalid header value: {e:?}"))?, + ); + } + Ok(header_map) } async fn request_bytes( @@ -151,15 +218,15 @@ async fn request_bytes( method: String, body: Vec, headers: Vec<(String, String)>, -) -> Result<(reqwest::StatusCode, Vec, HashMap), String> { - let resp = request(client, url, method, body, headers).await?; - let status = resp.status(); - let headers = process_headers(resp.headers()); - let body = resp - .bytes() - .await - .map_err(|e| format!("Failed to read response body: {e:?}"))? - .to_vec(); + allow_cache: bool, + cache: Cache, +) -> Result<(http::StatusCode, Vec, HashMap), String> { + let (parts, body) = request(client, url, method, body, headers, allow_cache, cache) + .await? + .into_parts(); + let status = parts.status; + let headers = process_headers(&parts.headers); + let body = body.try_into_bytes().await?; Ok((status, body, headers)) } @@ -174,40 +241,35 @@ async fn request_sse( body: Vec, headers: Vec<(String, String)>, rpc_pipe: Rc, + cache: Cache, ) -> Result<(), String> { trace!("Entering SSE"); let guard = guard((), |_| trace!("Exiting SSE due to cancellation")); - let response = request(client, url, method, body, headers).await?; - - let content_type = response.headers().get("content-type"); - let is_event_stream = content_type - .and_then(|v| v.to_str().ok()) - .map(|s| s.starts_with("text/event-stream")) - .unwrap_or(false); - - if !is_event_stream { - let headers = process_headers(response.headers()); - let status = response.status(); - let body = match response.bytes().await { - Ok(bytes) => bytes.to_vec(), - Err(e) => { - return Err(format!("Failed to read response body: {e:?}")); - } - }; - _ = rpc_pipe - .write(RustToPythonMessage::Response( - id, - (status.as_u16(), body, headers), - )) - .await; - - trace!("Exiting SSE due to non-SSE response"); - ScopeGuard::into_inner(guard); - return Ok(()); - } + let (parts, body) = request(client, url, method, body, headers, false, cache) + .await? + .into_parts(); + + let mut stream = match body { + MaybeResponse::Bytes(bytes) => { + let headers = process_headers(&parts.headers); + let status = parts.status; + let body = bytes; + _ = rpc_pipe + .write(RustToPythonMessage::Response( + id, + (status.as_u16(), body, headers), + )) + .await; + + trace!("Exiting SSE due to non-SSE response"); + ScopeGuard::into_inner(guard); + return Ok(()); + } + MaybeResponse::Stream(body) => body.into_data_stream().eventsource(), + }; - let headers = process_headers(response.headers()); - let status = response.status(); + let headers = process_headers(&parts.headers); + let status = parts.status; _ = rpc_pipe .write(RustToPythonMessage::SSEStart( id, @@ -215,7 +277,6 @@ async fn request_sse( )) .await; - let mut stream = response.bytes_stream().eventsource(); loop { let chunk = match stream.try_next().await { Ok(None) => break, @@ -247,7 +308,7 @@ async fn request_sse( Ok(()) } -fn process_headers(headers: &reqwest::header::HeaderMap) -> HashMap { +fn process_headers(headers: &HeaderMap) -> HashMap { headers .iter() .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) @@ -411,6 +472,8 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { .pool_idle_timeout(POOL_IDLE_TIMEOUT); let client_sse = client_sse.build().unwrap(); + let cache = Cache::new(); + let permit_manager = Rc::new(PermitManager::new(capacity)); let tasks = Arc::new(Mutex::new(HashMap::::new())); @@ -442,6 +505,7 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { client, client_sse, rpc_pipe, + cache.clone(), )); if let (Some(id), Some(backpressure)) = (id, backpressure) { tasks @@ -462,6 +526,7 @@ async fn execute( client: reqwest::Client, client_sse: reqwest::Client, rpc_pipe: Rc, + cache: Cache, ) { // If a request task was booted by this request, remove it from the list of // tasks when we exit. @@ -474,11 +539,11 @@ async fn execute( UpdateLimit(limit) => { permit_manager.update_limit(limit); } - Request(id, url, method, body, headers) => { + Request(id, url, method, body, headers, allow_cache) => { let Ok(permit) = permit_manager.acquire().await else { return; }; - match request_bytes(client, url, method, body, headers).await { + match request_bytes(client, url, method, body, headers, allow_cache, cache).await { Ok((status, body, headers)) => { _ = rpc_pipe .write(RustToPythonMessage::Response( @@ -513,6 +578,7 @@ async fn execute( body, headers, rpc_pipe.clone(), + cache, ) .await { @@ -609,8 +675,11 @@ impl Http { method: String, body: Vec, headers: Vec<(String, String)>, + cache: bool, ) -> PyResult<()> { - self.send(PythonToRustMessage::Request(id, url, method, body, headers)) + self.send(PythonToRustMessage::Request( + id, url, method, body, headers, cache, + )) } fn _request_sse( diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 1fd6323e9352..0920a23e4883 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -2253,7 +2253,7 @@ async def test_http_auth_ext_generic_oidc_authorize_01(self): ) redirect_to = f"{self.http_addr}/some/path" - _, headers, status = self.http_con_request( + body, headers, status = self.http_con_request( http_con, { "provider": provider_name, @@ -2263,7 +2263,7 @@ async def test_http_auth_ext_generic_oidc_authorize_01(self): path="authorize", ) - self.assertEqual(status, 302) + self.assertEqual(status, 302, body) location = headers.get("location") assert location is not None