From 8319bbe19692b3b52d7656190282849cf8ebe7fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Thu, 17 Oct 2024 19:26:03 +0200 Subject: [PATCH] Unify SQL connection settings by falling back to database config (#7863) --- docs/stdlib/cfg.rst | 4 ++ edb/lib/cfg.edgeql | 8 +++ edb/server/compiler/compiler.py | 12 +++++ edb/server/compiler/dbstate.py | 2 - edb/server/compiler/sql.py | 28 ++++++++--- tests/test_sql_query.py | 89 +++++++++++++++++++++++++++++---- 6 files changed, 125 insertions(+), 18 deletions(-) diff --git a/docs/stdlib/cfg.rst b/docs/stdlib/cfg.rst index dd23d2caa0c..18fbc11cef7 100644 --- a/docs/stdlib/cfg.rst +++ b/docs/stdlib/cfg.rst @@ -159,6 +159,10 @@ Query behavior UI session, so you won't have to remember to re-enable it when you're done. +:eql:synopsis:`apply_access_policies_sql -> bool` + Determines whether access policies should be applied when running queries over + SQL adapter. Defaults to ``false``. + :eql:synopsis:`force_database_error -> str` A hook to force all queries to produce an error. Defaults to 'false'. diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index eab87ce9349..bdca5d2300b 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -185,6 +185,14 @@ ALTER TYPE cfg::AbstractConfig { 'Whether access policies will be applied when running queries.'; }; + CREATE PROPERTY apply_access_policies_sql -> std::bool { + SET default := false; + CREATE ANNOTATION cfg::affects_compilation := 'false'; + CREATE ANNOTATION std::description := + 'Whether access policies will be applied when running queries over \ + SQL adapter.'; + }; + CREATE PROPERTY allow_user_specified_id -> std::bool { SET default := false; CREATE ANNOTATION cfg::affects_compilation := 'true'; diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index d88ba974ace..a3922c6db30 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -544,6 +544,16 @@ def compile_sql( ) schema = state.current_tx().get_schema(self.state.std_schema) + setting = database_config.get('allow_user_specified_id', None) + allow_user_specified_id = None + if setting and setting.value: + allow_user_specified_id = sql.is_setting_truthy(setting.value) + + setting = database_config.get('apply_access_policies_sql', None) + apply_access_policies_sql = None + if setting and setting.value: + apply_access_policies_sql = sql.is_setting_truthy(setting.value) + return sql.compile_sql( query_str, schema=schema, @@ -551,6 +561,8 @@ def compile_sql( prepared_stmt_map=prepared_stmt_map, current_database=current_database, current_user=current_user, + allow_user_specified_id=allow_user_specified_id, + apply_access_policies_sql=apply_access_policies_sql, ) def compile_request( diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 8b3cebb5e48..4ff1d55c1a1 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -629,8 +629,6 @@ class ParsedDatabase: DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map() DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map({ "search_path": ("public",), - "allow_user_specified_id": ("false",), - "apply_access_policies_sql": ("false",), "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), "server_version_num": cast( SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index 4176bc9e146..445abd21356 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -60,11 +60,15 @@ def compile_sql( prepared_stmt_map: Mapping[str, str], current_database: str, current_user: str, + allow_user_specified_id: Optional[bool], + apply_access_policies_sql: Optional[bool], ) -> List[dbstate.SQLQueryUnit]: opts = ResolverOptionsPartial( query_str=query_str, current_database=current_database, current_user=current_user, + allow_user_specified_id=allow_user_specified_id, + apply_access_policies_sql=apply_access_policies_sql, ) stmts = pg_parser.parse(query_str, propagate_spans=True) @@ -268,6 +272,8 @@ class ResolverOptionsPartial: current_user: str current_database: str query_str: str + allow_user_specified_id: Optional[bool] + apply_access_policies_sql: Optional[bool] def resolve_query( @@ -288,12 +294,16 @@ def resolve_query( allow_user_specified_id = lookup_bool_setting( tx_state, 'allow_user_specified_id' ) + if allow_user_specified_id is None: + allow_user_specified_id = opts.allow_user_specified_id if allow_user_specified_id is None: allow_user_specified_id = False apply_access_policies = lookup_bool_setting( tx_state, 'apply_access_policies_sql' ) + if apply_access_policies is None: + apply_access_policies = opts.apply_access_policies_sql if apply_access_policies is None: apply_access_policies = False @@ -317,15 +327,21 @@ def lookup_bool_setting( setting = tx_state.get(name) except KeyError: setting = None - if setting: - if isinstance(setting[0], str): - truthy = {'on', 'true', 'yes', '1'} - return setting[0].lower() in truthy - elif isinstance(setting[0], int): - return bool(setting[0]) + if setting and setting[0]: + return is_setting_truthy(setting[0]) return None +def is_setting_truthy(val: str | int | float) -> bool: + if isinstance(val, str): + truthy = {'on', 'true', 'yes', '1'} + return val.lower() in truthy + elif isinstance(val, int): + return bool(val) + else: + return False + + def compute_stmt_name(text: str, tx_state: dbstate.SQLTransactionState) -> str: stmt_hash = hashlib.sha1(text.encode("utf-8")) for setting_name in sorted(FE_SETTINGS_MUTABLE): diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index 66134e92ef1..db9579e1566 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -19,6 +19,7 @@ import csv import io import os.path +from typing import Optional import unittest import uuid @@ -39,6 +40,8 @@ class TestSQLQuery(tb.SQLQueryTestCase): os.path.dirname(__file__), 'schemas', 'inventory.esdl' ) + TRANSACTION_ISOLATION = False # needed for test_sql_query_set_04 + SETUP = [ ''' alter type novel { @@ -1069,6 +1072,78 @@ async def test_sql_query_set_03(self): res = await self.squery_values('SHOW search_path;') self.assertEqual(res, [["public"]]) + async def test_sql_query_set_04(self): + # database settings allow_user_specified_ids & apply_access_policies_sql + # should be unified over EdgeQL and SQL adapter + + async def set_current_database(val: Optional[bool]): + if val is None: + await self.con.execute( + f''' + configure current database + reset apply_access_policies_sql; + ''' + ) + else: + await self.con.execute( + f''' + configure current database + set apply_access_policies_sql := {str(val).lower()}; + ''' + ) + + async def set_sql(val: Optional[bool]): + if val is None: + await self.scon.execute( + f''' + RESET apply_access_policies_sql; + ''' + ) + else: + await self.scon.execute( + f''' + SET apply_access_policies_sql TO '{str(val).lower()}'; + ''' + ) + + async def are_policies_applied() -> bool: + res = await self.squery_values( + 'SELECT title FROM "Content" ORDER BY title' + ) + return len(res) == 0 + + await set_current_database(True) + await set_sql(True) + self.assertEqual(await are_policies_applied(), True) + + await set_sql(False) + self.assertEqual(await are_policies_applied(), False) + + await set_sql(None) + self.assertEqual(await are_policies_applied(), True) + + await set_current_database(False) + await set_sql(True) + self.assertEqual(await are_policies_applied(), True) + + await set_sql(False) + self.assertEqual(await are_policies_applied(), False) + + await set_sql(None) + self.assertEqual(await are_policies_applied(), False) + + await set_current_database(None) + await set_sql(True) + self.assertEqual(await are_policies_applied(), True) + + await set_sql(False) + self.assertEqual(await are_policies_applied(), False) + + await set_sql(None) + self.assertEqual(await are_policies_applied(), False) + + # setting cleanup not needed, since with end with the None, None + async def test_sql_query_static_eval_01(self): res = await self.squery_values('select current_schema;') self.assertEqual(res, [['public']]) @@ -1833,7 +1908,7 @@ async def test_sql_query_access_policy_03(self): ) with self.assertRaisesRegex( asyncpg.exceptions.InsufficientPrivilegeError, - 'access policy violation on insert of default::ContentSummary' + 'access policy violation on insert of default::ContentSummary', ): await self.scon.execute( 'INSERT INTO "ContentSummary" DEFAULT VALUES' @@ -1848,9 +1923,7 @@ async def test_sql_query_access_policy_04(self): await tran.start() # there is only one object that is of exactly type Content - res = await self.squery_values( - 'SELECT * FROM ONLY "Content"' - ) + res = await self.squery_values('SELECT * FROM ONLY "Content"') self.assertEqual(len(res), 1) await self.scon.execute('SET LOCAL apply_access_policies_sql TO true') @@ -1858,17 +1931,13 @@ async def test_sql_query_access_policy_04(self): await self.scon.execute( """SET LOCAL "global default::filter_title" TO 'Halo 3'""" ) - res = await self.squery_values( - 'SELECT * FROM ONLY "Content"' - ) + res = await self.squery_values('SELECT * FROM ONLY "Content"') self.assertEqual(len(res), 1) await self.scon.execute( """SET LOCAL "global default::filter_title" TO 'Forrest Gump'""" ) - res = await self.squery_values( - 'SELECT * FROM ONLY "Content"' - ) + res = await self.squery_values('SELECT * FROM ONLY "Content"') self.assertEqual(len(res), 0) await tran.rollback()