Skip to content

Commit

Permalink
Unify SQL connection settings by falling back to database config (#7863)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Oct 17, 2024
1 parent 9839c76 commit 8319bbe
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 18 deletions.
4 changes: 4 additions & 0 deletions docs/stdlib/cfg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ Query behavior
UI session, so you won't have to remember to re-enable it when you're
done.

:eql:synopsis:`apply_access_policies_sql -> bool`
Determines whether access policies should be applied when running queries over
SQL adapter. Defaults to ``false``.

:eql:synopsis:`force_database_error -> str`
A hook to force all queries to produce an error. Defaults to 'false'.

Expand Down
8 changes: 8 additions & 0 deletions edb/lib/cfg.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ ALTER TYPE cfg::AbstractConfig {
'Whether access policies will be applied when running queries.';
};

CREATE PROPERTY apply_access_policies_sql -> std::bool {
SET default := false;
CREATE ANNOTATION cfg::affects_compilation := 'false';
CREATE ANNOTATION std::description :=
'Whether access policies will be applied when running queries over \
SQL adapter.';
};

CREATE PROPERTY allow_user_specified_id -> std::bool {
SET default := false;
CREATE ANNOTATION cfg::affects_compilation := 'true';
Expand Down
12 changes: 12 additions & 0 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,25 @@ def compile_sql(
)
schema = state.current_tx().get_schema(self.state.std_schema)

setting = database_config.get('allow_user_specified_id', None)
allow_user_specified_id = None
if setting and setting.value:
allow_user_specified_id = sql.is_setting_truthy(setting.value)

setting = database_config.get('apply_access_policies_sql', None)
apply_access_policies_sql = None
if setting and setting.value:
apply_access_policies_sql = sql.is_setting_truthy(setting.value)

return sql.compile_sql(
query_str,
schema=schema,
tx_state=tx_state,
prepared_stmt_map=prepared_stmt_map,
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
)

def compile_request(
Expand Down
2 changes: 0 additions & 2 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,6 @@ class ParsedDatabase:
DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map()
DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map({
"search_path": ("public",),
"allow_user_specified_id": ("false",),
"apply_access_policies_sql": ("false",),
"server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)),
"server_version_num": cast(
SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,)
Expand Down
28 changes: 22 additions & 6 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,15 @@ def compile_sql(
prepared_stmt_map: Mapping[str, str],
current_database: str,
current_user: str,
allow_user_specified_id: Optional[bool],
apply_access_policies_sql: Optional[bool],
) -> List[dbstate.SQLQueryUnit]:
opts = ResolverOptionsPartial(
query_str=query_str,
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
)

stmts = pg_parser.parse(query_str, propagate_spans=True)
Expand Down Expand Up @@ -268,6 +272,8 @@ class ResolverOptionsPartial:
current_user: str
current_database: str
query_str: str
allow_user_specified_id: Optional[bool]
apply_access_policies_sql: Optional[bool]


def resolve_query(
Expand All @@ -288,12 +294,16 @@ def resolve_query(
allow_user_specified_id = lookup_bool_setting(
tx_state, 'allow_user_specified_id'
)
if allow_user_specified_id is None:
allow_user_specified_id = opts.allow_user_specified_id
if allow_user_specified_id is None:
allow_user_specified_id = False

apply_access_policies = lookup_bool_setting(
tx_state, 'apply_access_policies_sql'
)
if apply_access_policies is None:
apply_access_policies = opts.apply_access_policies_sql
if apply_access_policies is None:
apply_access_policies = False

Expand All @@ -317,15 +327,21 @@ def lookup_bool_setting(
setting = tx_state.get(name)
except KeyError:
setting = None
if setting:
if isinstance(setting[0], str):
truthy = {'on', 'true', 'yes', '1'}
return setting[0].lower() in truthy
elif isinstance(setting[0], int):
return bool(setting[0])
if setting and setting[0]:
return is_setting_truthy(setting[0])
return None


def is_setting_truthy(val: str | int | float) -> bool:
if isinstance(val, str):
truthy = {'on', 'true', 'yes', '1'}
return val.lower() in truthy
elif isinstance(val, int):
return bool(val)
else:
return False


def compute_stmt_name(text: str, tx_state: dbstate.SQLTransactionState) -> str:
stmt_hash = hashlib.sha1(text.encode("utf-8"))
for setting_name in sorted(FE_SETTINGS_MUTABLE):
Expand Down
89 changes: 79 additions & 10 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import csv
import io
import os.path
from typing import Optional
import unittest
import uuid

Expand All @@ -39,6 +40,8 @@ class TestSQLQuery(tb.SQLQueryTestCase):
os.path.dirname(__file__), 'schemas', 'inventory.esdl'
)

TRANSACTION_ISOLATION = False # needed for test_sql_query_set_04

SETUP = [
'''
alter type novel {
Expand Down Expand Up @@ -1069,6 +1072,78 @@ async def test_sql_query_set_03(self):
res = await self.squery_values('SHOW search_path;')
self.assertEqual(res, [["public"]])

async def test_sql_query_set_04(self):
# database settings allow_user_specified_ids & apply_access_policies_sql
# should be unified over EdgeQL and SQL adapter

async def set_current_database(val: Optional[bool]):
if val is None:
await self.con.execute(
f'''
configure current database
reset apply_access_policies_sql;
'''
)
else:
await self.con.execute(
f'''
configure current database
set apply_access_policies_sql := {str(val).lower()};
'''
)

async def set_sql(val: Optional[bool]):
if val is None:
await self.scon.execute(
f'''
RESET apply_access_policies_sql;
'''
)
else:
await self.scon.execute(
f'''
SET apply_access_policies_sql TO '{str(val).lower()}';
'''
)

async def are_policies_applied() -> bool:
res = await self.squery_values(
'SELECT title FROM "Content" ORDER BY title'
)
return len(res) == 0

await set_current_database(True)
await set_sql(True)
self.assertEqual(await are_policies_applied(), True)

await set_sql(False)
self.assertEqual(await are_policies_applied(), False)

await set_sql(None)
self.assertEqual(await are_policies_applied(), True)

await set_current_database(False)
await set_sql(True)
self.assertEqual(await are_policies_applied(), True)

await set_sql(False)
self.assertEqual(await are_policies_applied(), False)

await set_sql(None)
self.assertEqual(await are_policies_applied(), False)

await set_current_database(None)
await set_sql(True)
self.assertEqual(await are_policies_applied(), True)

await set_sql(False)
self.assertEqual(await are_policies_applied(), False)

await set_sql(None)
self.assertEqual(await are_policies_applied(), False)

# setting cleanup not needed, since with end with the None, None

async def test_sql_query_static_eval_01(self):
res = await self.squery_values('select current_schema;')
self.assertEqual(res, [['public']])
Expand Down Expand Up @@ -1833,7 +1908,7 @@ async def test_sql_query_access_policy_03(self):
)
with self.assertRaisesRegex(
asyncpg.exceptions.InsufficientPrivilegeError,
'access policy violation on insert of default::ContentSummary'
'access policy violation on insert of default::ContentSummary',
):
await self.scon.execute(
'INSERT INTO "ContentSummary" DEFAULT VALUES'
Expand All @@ -1848,27 +1923,21 @@ async def test_sql_query_access_policy_04(self):
await tran.start()

# there is only one object that is of exactly type Content
res = await self.squery_values(
'SELECT * FROM ONLY "Content"'
)
res = await self.squery_values('SELECT * FROM ONLY "Content"')
self.assertEqual(len(res), 1)

await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')

await self.scon.execute(
"""SET LOCAL "global default::filter_title" TO 'Halo 3'"""
)
res = await self.squery_values(
'SELECT * FROM ONLY "Content"'
)
res = await self.squery_values('SELECT * FROM ONLY "Content"')
self.assertEqual(len(res), 1)

await self.scon.execute(
"""SET LOCAL "global default::filter_title" TO 'Forrest Gump'"""
)
res = await self.squery_values(
'SELECT * FROM ONLY "Content"'
)
res = await self.squery_values('SELECT * FROM ONLY "Content"')
self.assertEqual(len(res), 0)

await tran.rollback()

0 comments on commit 8319bbe

Please sign in to comment.