Skip to content

Commit

Permalink
Fix SET within transactions over SQL adapter (#7710)
Browse files Browse the repository at this point in the history
When using SET (or SET SESSION) within a transaction, this SET is
applied only after the transaction is committed. It should be applied
immediately.

---------

Co-authored-by: Fantix King <[email protected]>
  • Loading branch information
aljazerzen and fantix authored Sep 9, 2024
1 parent b6b3cb3 commit f8146ad
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 31 deletions.
7 changes: 5 additions & 2 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,15 @@ class SQLTransactionState:

def current_fe_settings(self) -> SQLSettings:
if self.in_tx:
return self.in_tx_settings or DEFAULT_SQL_FE_SETTINGS
else:
return self.in_tx_local_settings or DEFAULT_SQL_FE_SETTINGS
else:
return self.settings or DEFAULT_SQL_FE_SETTINGS

def get(self, name: str) -> Optional[str | list[str]]:
if self.in_tx:
# For easier access, in_tx_local_settings is always a superset of
# in_tx_settings; in_tx_settings only keeps track of non-local
# settings, so that the local settings don't go across tx bounds
assert self.in_tx_local_settings
return self.in_tx_local_settings[name]
else:
Expand Down
41 changes: 25 additions & 16 deletions edb/server/protocol/pg_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,34 @@ def managed_error():
@cython.final
cdef class ConnectionView:
def __init__(self):
self._settings = DEFAULT_SETTINGS
self._fe_settings = DEFAULT_FE_SETTINGS
self._in_tx_explicit = False
self._in_tx_implicit = False

# Kepp track of backend settings so that we can sync to use different
# backend connections (pgcon) within the same frontend connection,
# see serialize_state() below and its usages in pgcon.pyx.
self._settings = DEFAULT_SETTINGS
self._in_tx_settings = None

# Frontend-only settings are defined by the high-level compiler, and
# tracked only here, syncing between the compiler process,
# see current_fe_settings(), fe_transaction_state() and usages below.
self._fe_settings = DEFAULT_FE_SETTINGS
self._in_tx_fe_settings = None
self._in_tx_fe_local_settings = None

self._in_tx_portals = {}
self._in_tx_new_portals = set()
self._in_tx_savepoints = collections.deque()
self._tx_error = False
self._session_state_db_cache = (DEFAULT_SETTINGS, DEFAULT_STATE)

def current_settings(self):
if self.in_tx():
return self._in_tx_settings or DEFAULT_SETTINGS
else:
return self._settings or DEFAULT_SETTINGS

cpdef inline current_fe_settings(self):
if self.in_tx():
# For easier access, _in_tx_fe_local_settings is always a superset
# of _in_tx_fe_settings; _in_tx_fe_settings only keeps track of
# non-local settings, so that the local settings don't go across
# transaction boundaries; this must be consistent with dbstate.py.
return self._in_tx_fe_local_settings or DEFAULT_FE_SETTINGS
else:
return self._fe_settings or DEFAULT_FE_SETTINGS
Expand All @@ -117,9 +124,7 @@ cdef class ConnectionView:
self._in_tx_explicit = chain_explicit
self._in_tx_settings = self._settings if self.in_tx() else None
self._in_tx_fe_settings = self._fe_settings if self.in_tx() else None
self._in_tx_fe_local_settings = (
self._fe_settings if self.in_tx() else None
)
self._in_tx_fe_local_settings = self._in_tx_fe_settings
self._in_tx_portals.clear()
self._in_tx_new_portals.clear()
self._in_tx_savepoints.clear()
Expand Down Expand Up @@ -247,8 +252,8 @@ cdef class ConnectionView:
else:
if self.in_tx():
if unit.frontend_only:
if unit.is_local:
settings = self._in_tx_fe_local_settings.mutate()
if not unit.is_local:
settings = self._in_tx_fe_settings.mutate()
for k, v in unit.set_vars.items():
if v is None:
if k in DEFAULT_FE_SETTINGS:
Expand All @@ -257,8 +262,8 @@ cdef class ConnectionView:
settings.pop(k, None)
else:
settings[k] = v
self._in_tx_fe_local_settings = settings.finish()
settings = self._in_tx_fe_settings.mutate()
self._in_tx_fe_settings = settings.finish()
settings = self._in_tx_fe_local_settings.mutate()
else:
settings = self._in_tx_settings.mutate()
elif not unit.is_local:
Expand All @@ -278,7 +283,7 @@ cdef class ConnectionView:
settings[k] = v
if self.in_tx():
if unit.frontend_only:
self._in_tx_fe_settings = settings.finish()
self._in_tx_fe_local_settings = settings.finish()
else:
self._in_tx_settings = settings.finish()
else:
Expand Down Expand Up @@ -991,7 +996,11 @@ cdef class PgConnection(frontend.FrontendConnection):
PGMessage parse_action
ConnectionView dbv

# Extended-query pre-plays on a deeply-cloned temporary dbview so as to
# compose the actions list with correct states; the actual changes to
# dbview is applied in pgcon.pyx when the actions are actually executed
dbv = copy.deepcopy(self._dbview)

actions = deque()
fresh_stmts = set()
in_implicit = self._dbview._in_tx_implicit
Expand Down
102 changes: 89 additions & 13 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ async def test_sql_query_introspection_04(self):
],
)

async def test_sql_query_schemas(self):
async def test_sql_query_schemas_01(self):
await self.scon.fetch('SELECT id FROM "inventory"."Item";')
await self.scon.fetch('SELECT id FROM "public"."Person";')

Expand Down Expand Up @@ -900,6 +900,79 @@ async def test_sql_query_schemas(self):
# HACK: Set search_path back to public
await self.scon.execute('SET search_path TO public;')

async def test_sql_query_set_01(self):
# initial state: search_path=public
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [['public']])

# enter transaction
tran = self.scon.transaction()
await tran.start()

# set
await self.scon.execute('SET LOCAL search_path TO inventory;')

await self.scon.fetch('SELECT id FROM "Item";')
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["'inventory'"]])

# finish
await tran.commit()

# because we used LOCAL, value should be reset after transaction is over
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["public"]])

async def test_sql_query_set_02(self):
# initial state: search_path=public
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [['public']])

# enter transaction
tran = self.scon.transaction()
await tran.start()

# set
await self.scon.execute('SET search_path TO inventory;')

res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["'inventory'"]])

# commit
await tran.commit()

# it should still be changed, since we SET was not LOCAL
await self.scon.fetch('SELECT id FROM "Item";')
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["'inventory'"]])

# reset to default value
await self.scon.execute('RESET search_path;')
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["public"]])

async def test_sql_query_set_03(self):
# initial state: search_path=public
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [['public']])

# start
tran = self.scon.transaction()
await tran.start()

# set
await self.scon.execute('SET search_path TO inventory;')

res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["'inventory'"]])

# rollback
await tran.rollback()

# because transaction was rolled back, value should be reset
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["public"]])

async def test_sql_query_static_eval_01(self):
res = await self.squery_values('select current_schema;')
self.assertEqual(res, [['public']])
Expand Down Expand Up @@ -988,18 +1061,21 @@ async def test_sql_query_static_eval_06(self):
ORDER BY relname;
"""
)
self.assertEqual(res, [
["Book", 8192],
["Book.chapters", 8192],
["Content", 8192],
["Genre", 8192],
["Movie", 8192],
["Movie.actors", 8192],
["Movie.director", 8192],
["Person", 8192],
["novel", 8192],
["novel.chapters", 0],
])
self.assertEqual(
res,
[
["Book", 8192],
["Book.chapters", 8192],
["Content", 8192],
["Genre", 8192],
["Movie", 8192],
["Movie.actors", 8192],
["Movie.director", 8192],
["Person", 8192],
["novel", 8192],
["novel.chapters", 0],
],
)

async def test_sql_query_be_state(self):
con = await self.connect(database=self.con.dbname)
Expand Down

0 comments on commit f8146ad

Please sign in to comment.