diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 77e84ca5905..3d7baf4168b 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -129,6 +129,7 @@ cdef class CompiledQuery: data = {} if self.tag: data['tag'] = self.tag + # maintenance reminder: please also update _amend_typedesc_in_sql() if data: data_bytes = json.dumps(data).encode(defines.EDGEDB_ENCODING) return b''.join([b'-- ', data_bytes, b'\n']) @@ -1311,6 +1312,7 @@ cdef class DatabaseConnectionView: use_metrics: bint = True, allow_capabilities: uint64_t = compiler.Capability.ALL, pgcon: pgcon.PGConnection | None = None, + tag: str | None = None, ) -> CompiledQuery: query_unit_group = None if self._query_cache_enabled: @@ -1410,6 +1412,7 @@ cdef class DatabaseConnectionView: query_req, query_unit_group, pgcon, + tag, ) if self._query_cache_enabled and query_unit_group.cacheable: @@ -1486,6 +1489,7 @@ cdef class DatabaseConnectionView: query_req: rpc.CompilationRequest, qug: dbstate.QueryUnitGroup, pgcon: pgcon.PGConnection, + tag: str | None, ) -> None: # The SQL QueryUnitGroup as initially returned from the compiler # is missing the input/output type descriptors because we currently @@ -1516,6 +1520,15 @@ cdef class DatabaseConnectionView: intro_sql = query_unit.introspection_sql if intro_sql is None: intro_sql = query_unit.sql + if tag is not None: + # maintenance reminder: please also update make_query_prefix() + tag_json = json.dumps({"tag": tag}) + intro_sql = b''.join([ + b'-- ', + tag_json.encode(defines.EDGEDB_ENCODING), + b'\n', + intro_sql, + ]) try: param_desc, result_desc = await pgcon.sql_describe( intro_sql, all_type_oids) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 801fee58ce2..78d17615e22 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -507,6 +507,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): self, rpc.CompilationRequest query_req, uint64_t allow_capabilities, + tag=None, ) -> dbview.CompiledQuery: cdef dbview.DatabaseConnectionView dbv dbv = self.get_dbview() @@ -546,6 +547,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): query_req, allow_capabilities=allow_capabilities, pgcon=pg_conn, + tag=tag, ) else: return await dbv.parse( @@ -957,7 +959,9 @@ cdef class EdgeConnection(frontend.FrontendConnection): if self.debug: self.debug_print('EXECUTE /CACHE MISS', query_req.source.text()) - compiled = await self._parse(query_req, allow_capabilities) + compiled = await self._parse( + query_req, allow_capabilities, tag + ) query_unit_group = compiled.query_unit_group if self._cancelled: raise ConnectionAbortedError diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 7f94e254652..97b80f23188 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -367,6 +367,9 @@ def remove_log_listener(self, callback): def _get_state(self): return self._options.state + def _get_annotations(self) -> typing.Dict[str, str]: + return self._options.annotations + def _warning_handler(self, warnings, res): if self._capture_warnings is not None: self._capture_warnings.extend(warnings) diff --git a/tests/test_edgeql_sys.py b/tests/test_edgeql_sys.py index 319fe92d7ae..819af1851fb 100644 --- a/tests/test_edgeql_sys.py +++ b/tests/test_edgeql_sys.py @@ -44,8 +44,11 @@ def _before_test_sys_query_stats(self): "can't run query stats test when extension isn't present" ) - async def _test_sys_query_stats(self): - stats_query = f''' + def make_stats_query(self, tag: str | None = None) -> str: + tag_filter = '' + if tag is not None: + tag_filter = f'and .tag = {common.quote_literal(tag)}' + return f''' with stats := ( select sys::QueryStats @@ -53,10 +56,14 @@ async def _test_sys_query_stats(self): .query like '%{self.stats_magic_word}%' and .query not like '%sys::%' and .query_type = $0 + {tag_filter} ) select sum(stats.calls) ''' + async def _test_sys_query_stats(self): + stats_query = self.make_stats_query() + # Take the initial tracking number of executions calls = await self.con.query_single(stats_query, self.stats_type) @@ -104,9 +111,6 @@ async def _test_sys_query_stats(self): ) # Turn cfg::Config.track_query_stats back on again - if self.stats_type == 'SQL': - # FIXME: don't return after fixing #8147 - return await self._configure_track('All') await self._query_for_stats() self.assertEqual( @@ -114,6 +118,25 @@ async def _test_sys_query_stats(self): 1, ) + async def _test_sys_query_stats_with_tag(self): + # Test tags are correctly set + tag = 'test_tag' + self.con = self.con.with_annotation('tag', tag) + self.stats_magic_word += "Tagged" + self.assertEqual( + await self.con.query_single( + self.make_stats_query(tag=tag), self.stats_type + ), + 0, + ) + await self._query_for_stats() + self.assertEqual( + await self.con.query_single( + self.make_stats_query(tag=tag), self.stats_type + ), + 1, + ) + class TestEdgeQLSys(tb.QueryTestCase, TestQueryStatsMixin): stats_magic_word = 'TestEdgeQLSys' @@ -184,6 +207,7 @@ async def test_edgeql_sys_query_stats(self): self.con = await sd.connect() try: await self._test_sys_query_stats() + await self._test_sys_query_stats_with_tag() finally: await self.con.aclose() self.con = old_con @@ -208,9 +232,8 @@ async def _query_for_stats(self): async def _configure_track(self, option: str): # XXX: we should probably translate the config name in the compiler, # so that we can use the frontend name (track_query_stats) here instead - # FIXME: drop lower() after fixing #8147 await self.scon.execute(f''' - set "edb_stat_statements.track" TO '{option.lower()}'; + set "edb_stat_statements.track" TO '{option}'; ''') async def _bad_query_for_stats(self): @@ -235,3 +258,43 @@ async def test_sql_sys_query_stats(self): await self.scon.close() await self.con.aclose() self.con, self.scon = old_cons + + +class TestQueryStatsSQLoverBianry(tb.QueryTestCase, TestQueryStatsMixin): + stats_magic_word = 'TestEdgeQLSysSQL' + stats_type = 'SQL' + + async def _query_for_stats(self): + self.counter += 1 + ident = self.stats_magic_word + str(self.counter) + records = await self.con.query_sql( + f"select {self.counter} as {common.quote_ident(ident)}" + ) + self.assertEqual(len(records), 1) + self.assertEqual(records[0].as_dict(), {ident: self.counter}) + + async def _configure_track(self, option: str): + await self.con.query(f''' + configure session set track_query_stats := + {common.quote_literal(option)}; + ''') + + async def _bad_query_for_stats(self): + async with self.assertRaisesRegexTx( + edgedb.QueryError, 'does not exist' + ): + await self.con.query_sql( + f'select {self.stats_magic_word}_NoSuchType' + ) + + async def test_edgeql_sys_query_stats_sql(self): + self._before_test_sys_query_stats() + async with tb.start_edgedb_server() as sd: + old_con = self.con + self.con = await sd.connect() + try: + await self._test_sys_query_stats() + await self._test_sys_query_stats_with_tag() + finally: + await self.con.aclose() + self.con = old_con