diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index 33297da7a19..805d89a9086 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -1501,7 +1501,7 @@ cdef class PgConnection(frontend.FrontendConnection): if self.debug: self.debug_print("Compile", query_str) fe_settings = dbv.current_fe_settings() - key = (hashlib.sha1(query_str.encode("utf-8")).digest(), fe_settings) + key = compute_cache_key(query_str, fe_settings) ignore_cache |= self._disable_cache @@ -1584,6 +1584,17 @@ cdef class PgConnection(frontend.FrontendConnection): return qu +def compute_cache_key( + query_str: str, fe_settings: dbstate.SQLSettings +) -> bytes: + h = hashlib.blake2b(query_str.encode("utf-8")) + for key, value in fe_settings.items(): + if key.startswith('global '): + continue + h.update(hash(value).to_bytes(8, signed=True)) + return h.digest() + + cdef WriteBuffer remap_arguments( data: bytes, params: list[dbstate.SQLParam] | None, diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 32e66d13b5c..051fbe1259a 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -2269,6 +2269,18 @@ async def connect(self, **kwargs: Any) -> tconn.Connection: conn_args = self.get_connect_args(**kwargs) return await tconn.async_connect_test_client(**conn_args) + async def connect_pg(self, **kwargs: Any) -> asyncpg.Connection: + import asyncpg + + conn_args = self.get_connect_args(**kwargs) + return await asyncpg.connect( + host=conn_args['host'], + port=conn_args['port'], + user=conn_args['user'], + password=conn_args['password'], + ssl='require' + ) + async def connect_test_protocol(self, **kwargs): conn_args = self.get_connect_args(**kwargs) conn = await test_protocol.new_connection(**conn_args) diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 9d4aaf542d9..843f70f0507 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -727,6 +727,14 @@ def measure_compilations( '{tenant="localtest",path="compiler"}' ) or 0 + def measure_sql_compilations( + sd: tb._EdgeDBServerData + ) -> Callable[[], float | int]: + return lambda: tb.parse_metrics(sd.fetch_metrics()).get( + 'edgedb_server_sql_compilations_total' + '{tenant="localtest"}' + ) or 0 + with tempfile.TemporaryDirectory() as temp_dir: async with tb.start_edgedb_server( data_dir=temp_dir, @@ -822,6 +830,42 @@ def measure_compilations( finally: await con.aclose() + has_asyncpg = True + try: + import asyncpg # noqa + except ImportError: + has_asyncpg = False + + if has_asyncpg: + scon = await sd.connect_pg() + try: + with self.assertChange(measure_sql_compilations(sd), 1): + await scon.fetch('select 1') + + with self.assertChange(measure_sql_compilations(sd), 1): + await scon.fetch('select 1 + 1') + + # cache hit + with self.assertChange(measure_sql_compilations(sd), 0): + await scon.fetch('select 1') + + # TODO: normalization & constant extraction + with self.assertChange(measure_sql_compilations(sd), 2): + await scon.fetch('select 2') + await scon.fetch('sELEcT 1') + + # cache hit, even after global has been changed + await scon.execute('SET "global default::g" to 1') + with self.assertChange(measure_sql_compilations(sd), 0): + await scon.execute('select 1') + + # compiler call, because config was changed + await scon.execute('SET apply_access_policies_sql to 1') + with self.assertChange(measure_sql_compilations(sd), 1): + await scon.execute('select 1') + finally: + await scon.close() + # Now restart the server to test the cache persistence. async with tb.start_edgedb_server( data_dir=temp_dir,