diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 0e96a4f1f98..2afb8a043e0 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -593,12 +593,15 @@ class SQLTransactionState: def current_fe_settings(self) -> SQLSettings: if self.in_tx: - return self.in_tx_settings or DEFAULT_SQL_FE_SETTINGS - else: return self.in_tx_local_settings or DEFAULT_SQL_FE_SETTINGS + else: + return self.settings or DEFAULT_SQL_FE_SETTINGS def get(self, name: str) -> Optional[str | list[str]]: if self.in_tx: + # For easier access, in_tx_local_settings is always a superset of + # in_tx_settings; in_tx_settings only keeps track of non-local + # settings, so that the local settings don't go across tx bounds assert self.in_tx_local_settings return self.in_tx_local_settings[name] else: diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index ba08014b765..9ea121fda81 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -70,27 +70,34 @@ def managed_error(): @cython.final cdef class ConnectionView: def __init__(self): - self._settings = DEFAULT_SETTINGS - self._fe_settings = DEFAULT_FE_SETTINGS self._in_tx_explicit = False self._in_tx_implicit = False + + # Kepp track of backend settings so that we can sync to use different + # backend connections (pgcon) within the same frontend connection, + # see serialize_state() below and its usages in pgcon.pyx. + self._settings = DEFAULT_SETTINGS self._in_tx_settings = None + + # Frontend-only settings are defined by the high-level compiler, and + # tracked only here, syncing between the compiler process, + # see current_fe_settings(), fe_transaction_state() and usages below. + self._fe_settings = DEFAULT_FE_SETTINGS self._in_tx_fe_settings = None self._in_tx_fe_local_settings = None + self._in_tx_portals = {} self._in_tx_new_portals = set() self._in_tx_savepoints = collections.deque() self._tx_error = False self._session_state_db_cache = (DEFAULT_SETTINGS, DEFAULT_STATE) - def current_settings(self): - if self.in_tx(): - return self._in_tx_settings or DEFAULT_SETTINGS - else: - return self._settings or DEFAULT_SETTINGS - cpdef inline current_fe_settings(self): if self.in_tx(): + # For easier access, _in_tx_fe_local_settings is always a superset + # of _in_tx_fe_settings; _in_tx_fe_settings only keeps track of + # non-local settings, so that the local settings don't go across + # transaction boundaries; this must be consistent with dbstate.py. return self._in_tx_fe_local_settings or DEFAULT_FE_SETTINGS else: return self._fe_settings or DEFAULT_FE_SETTINGS @@ -117,9 +124,7 @@ cdef class ConnectionView: self._in_tx_explicit = chain_explicit self._in_tx_settings = self._settings if self.in_tx() else None self._in_tx_fe_settings = self._fe_settings if self.in_tx() else None - self._in_tx_fe_local_settings = ( - self._fe_settings if self.in_tx() else None - ) + self._in_tx_fe_local_settings = self._in_tx_fe_settings self._in_tx_portals.clear() self._in_tx_new_portals.clear() self._in_tx_savepoints.clear() @@ -247,8 +252,8 @@ cdef class ConnectionView: else: if self.in_tx(): if unit.frontend_only: - if unit.is_local: - settings = self._in_tx_fe_local_settings.mutate() + if not unit.is_local: + settings = self._in_tx_fe_settings.mutate() for k, v in unit.set_vars.items(): if v is None: if k in DEFAULT_FE_SETTINGS: @@ -257,8 +262,8 @@ cdef class ConnectionView: settings.pop(k, None) else: settings[k] = v - self._in_tx_fe_local_settings = settings.finish() - settings = self._in_tx_fe_settings.mutate() + self._in_tx_fe_settings = settings.finish() + settings = self._in_tx_fe_local_settings.mutate() else: settings = self._in_tx_settings.mutate() elif not unit.is_local: @@ -278,7 +283,7 @@ cdef class ConnectionView: settings[k] = v if self.in_tx(): if unit.frontend_only: - self._in_tx_fe_settings = settings.finish() + self._in_tx_fe_local_settings = settings.finish() else: self._in_tx_settings = settings.finish() else: @@ -991,7 +996,11 @@ cdef class PgConnection(frontend.FrontendConnection): PGMessage parse_action ConnectionView dbv + # Extended-query pre-plays on a deeply-cloned temporary dbview so as to + # compose the actions list with correct states; the actual changes to + # dbview is applied in pgcon.pyx when the actions are actually executed dbv = copy.deepcopy(self._dbview) + actions = deque() fresh_stmts = set() in_implicit = self._dbview._in_tx_implicit diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 309c5266448..5b9d0d67032 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -859,7 +859,7 @@ async def test_sql_query_introspection_04(self): ], ) - async def test_sql_query_schemas(self): + async def test_sql_query_schemas_01(self): await self.scon.fetch('SELECT id FROM "inventory"."Item";') await self.scon.fetch('SELECT id FROM "public"."Person";') @@ -900,6 +900,79 @@ async def test_sql_query_schemas(self): # HACK: Set search_path back to public await self.scon.execute('SET search_path TO public;') + async def test_sql_query_set_01(self): + # initial state: search_path=public + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [['public']]) + + # enter transaction + tran = self.scon.transaction() + await tran.start() + + # set + await self.scon.execute('SET LOCAL search_path TO inventory;') + + await self.scon.fetch('SELECT id FROM "Item";') + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["'inventory'"]]) + + # finish + await tran.commit() + + # because we used LOCAL, value should be reset after transaction is over + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["public"]]) + + async def test_sql_query_set_02(self): + # initial state: search_path=public + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [['public']]) + + # enter transaction + tran = self.scon.transaction() + await tran.start() + + # set + await self.scon.execute('SET search_path TO inventory;') + + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["'inventory'"]]) + + # commit + await tran.commit() + + # it should still be changed, since we SET was not LOCAL + await self.scon.fetch('SELECT id FROM "Item";') + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["'inventory'"]]) + + # reset to default value + await self.scon.execute('RESET search_path;') + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["public"]]) + + async def test_sql_query_set_03(self): + # initial state: search_path=public + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [['public']]) + + # start + tran = self.scon.transaction() + await tran.start() + + # set + await self.scon.execute('SET search_path TO inventory;') + + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["'inventory'"]]) + + # rollback + await tran.rollback() + + # because transaction was rolled back, value should be reset + res = await self.squery_values('SHOW search_path;') + self.assertEqual(res, [["public"]]) + async def test_sql_query_static_eval_01(self): res = await self.squery_values('select current_schema;') self.assertEqual(res, [['public']]) @@ -988,18 +1061,21 @@ async def test_sql_query_static_eval_06(self): ORDER BY relname; """ ) - self.assertEqual(res, [ - ["Book", 8192], - ["Book.chapters", 8192], - ["Content", 8192], - ["Genre", 8192], - ["Movie", 8192], - ["Movie.actors", 8192], - ["Movie.director", 8192], - ["Person", 8192], - ["novel", 8192], - ["novel.chapters", 0], - ]) + self.assertEqual( + res, + [ + ["Book", 8192], + ["Book.chapters", 8192], + ["Content", 8192], + ["Genre", 8192], + ["Movie", 8192], + ["Movie.actors", 8192], + ["Movie.director", 8192], + ["Person", 8192], + ["novel", 8192], + ["novel.chapters", 0], + ], + ) async def test_sql_query_be_state(self): con = await self.connect(database=self.con.dbname)