Skip to content

Commit

Permalink
Implement caching for JWKSet (#8332)
Browse files Browse the repository at this point in the history
In the previous HTTP rewrite, we accidentally removed caching for JWK
keys and OIDC configurations. This restores it using an in-memory cache.

This adds a cache parameter to the Python-side `get(...)`.

The cache is enabled only for OIDC and JWK discovery requests -- the
assumption is that we'd prefer fresh user information when fetching
during login over speed.
  • Loading branch information
mmastrac authored Feb 12, 2025
1 parent 4bcd38c commit 2e9f16e
Show file tree
Hide file tree
Showing 14 changed files with 673 additions and 103 deletions.
137 changes: 112 additions & 25 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ members = [
"rust/gel-auth",
"rust/gel-jwt",
"rust/gel-stream",
"rust/gel-http",
"rust/pgrust",
"rust/http",
"rust/pyo3_util",
]
resolver = "2"
Expand All @@ -30,7 +30,7 @@ db_proto = { path = "rust/db_proto" }
captive_postgres = { path = "rust/captive_postgres" }
conn_pool = { path = "rust/conn_pool" }
pgrust = { path = "rust/pgrust" }
http = { path = "rust/http" }
gel-http = { path = "rust/gel-http" }
pyo3_util = { path = "rust/pyo3_util" }

[profile.release]
Expand Down
2 changes: 1 addition & 1 deletion edb/server/_rust_native/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pyo3 = { workspace = true }
pyo3_util.workspace = true
conn_pool = { workspace = true, features = [ "python_extension" ] }
pgrust = { workspace = true, features = [ "python_extension" ] }
http = { workspace = true, features = [ "python_extension" ] }
gel-http = { workspace = true, features = [ "python_extension" ] }
gel-auth = { workspace = true, features = [ "python_extension" ] }
gel-jwt = { workspace = true, features = [ "python_extension" ] }

Expand Down
2 changes: 1 addition & 1 deletion edb/server/_rust_native/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn _rust_native(py: Python, m: &Bound<PyModule>) -> PyResult<()> {

add_child_module(py, m, "_conn_pool", conn_pool::python::_conn_pool)?;
add_child_module(py, m, "_pg_rust", pgrust::python::_pg_rust)?;
add_child_module(py, m, "_http", http::python::_http)?;
add_child_module(py, m, "_http", gel_http::python::_gel_http)?;
add_child_module(py, m, "_jwt", gel_jwt::python::_jwt)?;

Ok(())
Expand Down
31 changes: 26 additions & 5 deletions edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ async def request(
headers: HeaderType = None,
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
cache: bool = False,
) -> tuple[int, bytearray, dict[str, str]]:
self._ensure_task()
path = self._process_path(path)
Expand All @@ -191,7 +192,9 @@ async def request(
self._requests[id] = asyncio.Future()
start_time = time.monotonic()
try:
self._ensure_client()._request(id, path, method, data, headers_list)
self._ensure_client()._request(
id, path, method, data, headers_list, cache
)
resp = await self._requests[id]
if self._stat_callback:
status_code, body, headers = resp
Expand All @@ -217,9 +220,15 @@ async def request(
finally:
del self._requests[id]

async def get(self, path: str, *, headers: HeaderType = None) -> Response:
async def get(
self,
path: str,
*,
headers: HeaderType = None,
cache: bool = False,
) -> Response:
result = await self.request(
method="GET", path=path, data=None, headers=headers
method="GET", path=path, data=None, headers=headers, cache=cache
)
return Response.from_tuple(result)

Expand Down Expand Up @@ -387,12 +396,24 @@ def _process_path(self, path):
return path

async def request(
self, *, method, path, headers=None, data=None, json=None
self,
*,
method,
path,
headers=None,
data=None,
json=None,
cache=False,
):
path = self._process_path(path)
headers = self._process_headers(headers)
return await self.http_client.request(
method=method, path=path, headers=headers, data=data, json=json
method=method,
path=path,
headers=headers,
data=data,
json=json,
cache=cache,
)

async def stream_sse(
Expand Down
7 changes: 5 additions & 2 deletions edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def fetch_user_info(
async with self.http_factory(
base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}"
) as client:
r = await client.get(jwks_uri.path)
r = await client.get(jwks_uri.path, cache=True)

# Load the token as a JWT object and verify it directly
try:
Expand All @@ -178,6 +178,9 @@ async def fetch_user_info(

async def _get_oidc_config(self) -> data.OpenIDConfig:
client = self.http_factory(base_url=self.issuer_url)
response = await client.get('/.well-known/openid-configuration')
response = await client.get(
'/.well-known/openid-configuration',
cache=True
)
config = response.json()
return data.OpenIDConfig(**config)
2 changes: 1 addition & 1 deletion edb/server/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_http_client(self, *, originator: str) -> HttpClient:
user_agent=f"EdgeDB {buildmeta.get_version_string(short=True)}",
stat_callback=lambda stat: logger.debug(
f"HTTP stat: {originator} {stat}"
),
)
)
return self._http_client

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies = [
'psutil~=5.8',
'setproctitle~=1.2',

'hishel==0.0.24',
'webauthn~=2.0.0',
'argon2-cffi~=23.1.0',
'aiosmtplib~=3.0',
Expand Down
Loading

0 comments on commit 2e9f16e

Please sign in to comment.