Skip to content

Commit

Permalink
Fix tagging SQL over binary protocol (geldata#8371)
Browse files Browse the repository at this point in the history
SQL over binary protocol will `Parse` first for `_amend_typedesc_in_sql()`, where we didn't send tags and the edb_stat_statements remembered that for all future commands of the same query.
  • Loading branch information
fantix authored Feb 26, 2025
1 parent 344315b commit 2fec20f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 8 deletions.
13 changes: 13 additions & 0 deletions edb/server/dbview/dbview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cdef class CompiledQuery:
data = {}
if self.tag:
data['tag'] = self.tag
# maintenance reminder: please also update _amend_typedesc_in_sql()
if data:
data_bytes = json.dumps(data).encode(defines.EDGEDB_ENCODING)
return b''.join([b'-- ', data_bytes, b'\n'])
Expand Down Expand Up @@ -1311,6 +1312,7 @@ cdef class DatabaseConnectionView:
use_metrics: bint = True,
allow_capabilities: uint64_t = <uint64_t>compiler.Capability.ALL,
pgcon: pgcon.PGConnection | None = None,
tag: str | None = None,
) -> CompiledQuery:
query_unit_group = None
if self._query_cache_enabled:
Expand Down Expand Up @@ -1410,6 +1412,7 @@ cdef class DatabaseConnectionView:
query_req,
query_unit_group,
pgcon,
tag,
)

if self._query_cache_enabled and query_unit_group.cacheable:
Expand Down Expand Up @@ -1486,6 +1489,7 @@ cdef class DatabaseConnectionView:
query_req: rpc.CompilationRequest,
qug: dbstate.QueryUnitGroup,
pgcon: pgcon.PGConnection,
tag: str | None,
) -> None:
# The SQL QueryUnitGroup as initially returned from the compiler
# is missing the input/output type descriptors because we currently
Expand Down Expand Up @@ -1516,6 +1520,15 @@ cdef class DatabaseConnectionView:
intro_sql = query_unit.introspection_sql
if intro_sql is None:
intro_sql = query_unit.sql
if tag is not None:
# maintenance reminder: please also update make_query_prefix()
tag_json = json.dumps({"tag": tag})
intro_sql = b''.join([
b'-- ',
tag_json.encode(defines.EDGEDB_ENCODING),
b'\n',
intro_sql,
])
try:
param_desc, result_desc = await pgcon.sql_describe(
intro_sql, all_type_oids)
Expand Down
6 changes: 5 additions & 1 deletion edb/server/protocol/binary.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
self,
rpc.CompilationRequest query_req,
uint64_t allow_capabilities,
tag=None,
) -> dbview.CompiledQuery:
cdef dbview.DatabaseConnectionView dbv
dbv = self.get_dbview()
Expand Down Expand Up @@ -546,6 +547,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
query_req,
allow_capabilities=allow_capabilities,
pgcon=pg_conn,
tag=tag,
)
else:
return await dbv.parse(
Expand Down Expand Up @@ -957,7 +959,9 @@ cdef class EdgeConnection(frontend.FrontendConnection):
if self.debug:
self.debug_print('EXECUTE /CACHE MISS', query_req.source.text())

compiled = await self._parse(query_req, allow_capabilities)
compiled = await self._parse(
query_req, allow_capabilities, tag
)
query_unit_group = compiled.query_unit_group
if self._cancelled:
raise ConnectionAbortedError
Expand Down
3 changes: 3 additions & 0 deletions edb/testbase/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ def remove_log_listener(self, callback):
def _get_state(self):
return self._options.state

def _get_annotations(self) -> typing.Dict[str, str]:
return self._options.annotations

def _warning_handler(self, warnings, res):
if self._capture_warnings is not None:
self._capture_warnings.extend(warnings)
Expand Down
77 changes: 70 additions & 7 deletions tests/test_edgeql_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,26 @@ def _before_test_sys_query_stats(self):
"can't run query stats test when extension isn't present"
)

async def _test_sys_query_stats(self):
stats_query = f'''
def make_stats_query(self, tag: str | None = None) -> str:
tag_filter = ''
if tag is not None:
tag_filter = f'and .tag = {common.quote_literal(tag)}'
return f'''
with stats := (
select
sys::QueryStats
filter
.query like '%{self.stats_magic_word}%'
and .query not like '%sys::%'
and .query_type = <sys::QueryType>$0
{tag_filter}
)
select sum(stats.calls)
'''

async def _test_sys_query_stats(self):
stats_query = self.make_stats_query()

# Take the initial tracking number of executions
calls = await self.con.query_single(stats_query, self.stats_type)

Expand Down Expand Up @@ -104,16 +111,32 @@ async def _test_sys_query_stats(self):
)

# Turn cfg::Config.track_query_stats back on again
if self.stats_type == 'SQL':
# FIXME: don't return after fixing #8147
return
await self._configure_track('All')
await self._query_for_stats()
self.assertEqual(
await self.con.query_single(stats_query, self.stats_type),
1,
)

async def _test_sys_query_stats_with_tag(self):
# Test tags are correctly set
tag = 'test_tag'
self.con = self.con.with_annotation('tag', tag)
self.stats_magic_word += "Tagged"
self.assertEqual(
await self.con.query_single(
self.make_stats_query(tag=tag), self.stats_type
),
0,
)
await self._query_for_stats()
self.assertEqual(
await self.con.query_single(
self.make_stats_query(tag=tag), self.stats_type
),
1,
)


class TestEdgeQLSys(tb.QueryTestCase, TestQueryStatsMixin):
stats_magic_word = 'TestEdgeQLSys'
Expand Down Expand Up @@ -184,6 +207,7 @@ async def test_edgeql_sys_query_stats(self):
self.con = await sd.connect()
try:
await self._test_sys_query_stats()
await self._test_sys_query_stats_with_tag()
finally:
await self.con.aclose()
self.con = old_con
Expand All @@ -208,9 +232,8 @@ async def _query_for_stats(self):
async def _configure_track(self, option: str):
# XXX: we should probably translate the config name in the compiler,
# so that we can use the frontend name (track_query_stats) here instead
# FIXME: drop lower() after fixing #8147
await self.scon.execute(f'''
set "edb_stat_statements.track" TO '{option.lower()}';
set "edb_stat_statements.track" TO '{option}';
''')

async def _bad_query_for_stats(self):
Expand All @@ -235,3 +258,43 @@ async def test_sql_sys_query_stats(self):
await self.scon.close()
await self.con.aclose()
self.con, self.scon = old_cons


class TestQueryStatsSQLoverBianry(tb.QueryTestCase, TestQueryStatsMixin):
stats_magic_word = 'TestEdgeQLSysSQL'
stats_type = 'SQL'

async def _query_for_stats(self):
self.counter += 1
ident = self.stats_magic_word + str(self.counter)
records = await self.con.query_sql(
f"select {self.counter} as {common.quote_ident(ident)}"
)
self.assertEqual(len(records), 1)
self.assertEqual(records[0].as_dict(), {ident: self.counter})

async def _configure_track(self, option: str):
await self.con.query(f'''
configure session set track_query_stats :=
<cfg::QueryStatsOption>{common.quote_literal(option)};
''')

async def _bad_query_for_stats(self):
async with self.assertRaisesRegexTx(
edgedb.QueryError, 'does not exist'
):
await self.con.query_sql(
f'select {self.stats_magic_word}_NoSuchType'
)

async def test_edgeql_sys_query_stats_sql(self):
self._before_test_sys_query_stats()
async with tb.start_edgedb_server() as sd:
old_con = self.con
self.con = await sd.connect()
try:
await self._test_sys_query_stats()
await self._test_sys_query_stats_with_tag()
finally:
await self.con.aclose()
self.con = old_con

0 comments on commit 2fec20f

Please sign in to comment.