Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename Config::apply_access_policies_sql to Config::apply_access_policies_pg #8075

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/reference/sql_adapter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ construct is mapped to PostgreSQL schema:
SET "global default::username" TO 'Tom'``.

- Access policies are applied to object type tables when setting
``apply_access_policies_sql`` is set to ``true``.
``apply_access_policies_pg`` is set to ``true``.

- Mutation rewrites and triggers are applied to all DML commands.

Expand Down Expand Up @@ -342,10 +342,10 @@ SQL adapter supports a limited subset of PostgreSQL connection settings.
There are the following additionally connection settings:

- ``allow_user_specified_id`` (default ``false``),
- ``apply_access_policies_sql`` (default ``false``),
- ``apply_access_policies_pg`` (default ``false``),
- settings prefixed with ``"global "`` can use used to set values of globals.

Note that if ``allow_user_specified_id`` or ``apply_access_policies_sql`` are
Note that if ``allow_user_specified_id`` or ``apply_access_policies_pg`` are
unset, they default to configuration set by ``configure current database``
EdgeQL command.

Expand Down
2 changes: 1 addition & 1 deletion docs/stdlib/cfg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Query behavior
UI session, so you won't have to remember to re-enable it when you're
done.

:eql:synopsis:`apply_access_policies_sql -> bool`
:eql:synopsis:`apply_access_policies_pg -> bool`
Determines whether access policies should be applied when running queries over
SQL adapter. Defaults to ``false``.

Expand Down
2 changes: 1 addition & 1 deletion edb/lib/cfg.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ ALTER TYPE cfg::AbstractConfig {
'Whether access policies will be applied when running queries.';
};

CREATE PROPERTY apply_access_policies_sql -> std::bool {
CREATE PROPERTY apply_access_policies_pg -> std::bool {
SET default := false;
CREATE ANNOTATION cfg::affects_compilation := 'false';
CREATE ANNOTATION std::description :=
Expand Down
10 changes: 5 additions & 5 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,10 @@ def compile_sql(
if setting and setting.value:
allow_user_specified_id = sql.is_setting_truthy(setting.value)

setting = database_config.get('apply_access_policies_sql', None)
apply_access_policies_sql = None
setting = database_config.get('apply_access_policies_pg', None)
apply_access_policies_pg = None
if setting and setting.value:
apply_access_policies_sql = sql.is_setting_truthy(setting.value)
apply_access_policies_pg = sql.is_setting_truthy(setting.value)

return sql.compile_sql(
query_str,
Expand All @@ -565,7 +565,7 @@ def compile_sql(
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
apply_access_policies=apply_access_policies_pg,
disambiguate_column_names=False,
backend_runtime_params=self.state.backend_runtime_params,
protocol_version=defines.POSTGRES_PROTOCOL,
Expand Down Expand Up @@ -2516,7 +2516,7 @@ def compile_sql_as_unit_group(
current_database=ctx.branch_name or "<unknown>",
current_user=ctx.role_name or "<unknown>",
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies,
apply_access_policies=apply_access_policies,
include_edgeql_io_format_alternative=True,
allow_prepared_statements=False,
disambiguate_column_names=True,
Expand Down
18 changes: 9 additions & 9 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
{
'search_path': True,
'allow_user_specified_id': True,
'apply_access_policies_sql': True,
'apply_access_policies_pg': True,
'server_version': False,
'server_version_num': False,
}
Expand All @@ -66,7 +66,7 @@ def compile_sql(
current_database: str,
current_user: str,
allow_user_specified_id: Optional[bool],
apply_access_policies_sql: Optional[bool],
apply_access_policies: Optional[bool],
include_edgeql_io_format_alternative: bool = False,
allow_prepared_statements: bool = True,
disambiguate_column_names: bool,
Expand All @@ -78,7 +78,7 @@ def compile_sql(
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
apply_access_policies=apply_access_policies,
include_edgeql_io_format_alternative=(
include_edgeql_io_format_alternative
),
Expand Down Expand Up @@ -329,10 +329,10 @@ def compile_sql(
'allow_user_specified_id',
('true' if allow_user_specified_id else 'false',),
)
if apply_access_policies_sql is not None:
if apply_access_policies is not None:
cconfig.setdefault(
'apply_access_policies_sql',
('true' if apply_access_policies_sql else 'false',),
'apply_access_policies',
('true' if apply_access_policies else 'false',),
)
search_path = parse_search_path(cconfig.pop("search_path", ("",)))
cconfig = dict(sorted((k, v) for k, v in cconfig.items()))
Expand Down Expand Up @@ -389,7 +389,7 @@ class ResolverOptionsPartial:
current_database: str
query_str: str
allow_user_specified_id: Optional[bool]
apply_access_policies_sql: Optional[bool]
apply_access_policies: Optional[bool]
include_edgeql_io_format_alternative: Optional[bool]
disambiguate_column_names: bool

Expand Down Expand Up @@ -422,10 +422,10 @@ def resolve_query(
allow_user_specified_id = False

apply_access_policies = lookup_bool_setting(
tx_state, 'apply_access_policies_sql'
tx_state, 'apply_access_policies_pg'
)
if apply_access_policies is None:
apply_access_policies = opts.apply_access_policies_sql
apply_access_policies = opts.apply_access_policies
if apply_access_policies is None:
apply_access_policies = False

Expand Down
13 changes: 12 additions & 1 deletion edb/server/protocol/pg_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ cdef class PgConnection(frontend.FrontendConnection):
if self.debug:
self.debug_print("Compile", query_str)
fe_settings = dbv.current_fe_settings()
key = (hashlib.sha1(query_str.encode("utf-8")).digest(), fe_settings)
key = compute_cache_key(query_str, fe_settings)

ignore_cache |= self._disable_cache

Expand Down Expand Up @@ -1584,6 +1584,17 @@ cdef class PgConnection(frontend.FrontendConnection):
return qu


def compute_cache_key(
query_str: str, fe_settings: dbstate.SQLSettings
) -> bytes:
h = hashlib.blake2b(query_str.encode("utf-8"))
for key, value in fe_settings.items():
if key.startswith('global '):
continue
h.update(hash(value).to_bytes(8, signed=True))
return h.digest()


cdef WriteBuffer remap_arguments(
data: bytes,
params: list[dbstate.SQLParam] | None,
Expand Down
12 changes: 12 additions & 0 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,6 +2269,18 @@ async def connect(self, **kwargs: Any) -> tconn.Connection:
conn_args = self.get_connect_args(**kwargs)
return await tconn.async_connect_test_client(**conn_args)

async def connect_pg(self, **kwargs: Any) -> asyncpg.Connection:
import asyncpg

conn_args = self.get_connect_args(**kwargs)
return await asyncpg.connect(
host=conn_args['host'],
port=conn_args['port'],
user=conn_args['user'],
password=conn_args['password'],
ssl='require'
)

async def connect_test_protocol(self, **kwargs):
conn_args = self.get_connect_args(**kwargs)
conn = await test_protocol.new_connection(**conn_args)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_server_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,14 @@ def measure_compilations(
'{tenant="localtest",path="compiler"}'
) or 0

def measure_sql_compilations(
sd: tb._EdgeDBServerData
) -> Callable[[], float | int]:
return lambda: tb.parse_metrics(sd.fetch_metrics()).get(
'edgedb_server_sql_compilations_total'
'{tenant="localtest"}'
) or 0

with tempfile.TemporaryDirectory() as temp_dir:
async with tb.start_edgedb_server(
data_dir=temp_dir,
Expand Down Expand Up @@ -822,6 +830,42 @@ def measure_compilations(
finally:
await con.aclose()

has_asyncpg = True
try:
import asyncpg # noqa
except ImportError:
has_asyncpg = False

if has_asyncpg:
scon = await sd.connect_pg()
try:
with self.assertChange(measure_sql_compilations(sd), 1):
await scon.fetch('select 1')

with self.assertChange(measure_sql_compilations(sd), 1):
await scon.fetch('select 1 + 1')

# cache hit
with self.assertChange(measure_sql_compilations(sd), 0):
await scon.fetch('select 1')

# TODO: normalization & constant extraction
with self.assertChange(measure_sql_compilations(sd), 2):
await scon.fetch('select 2')
await scon.fetch('sELEcT 1')

# cache hit, even after global has been changed
await scon.execute('SET "global default::g" to 1')
with self.assertChange(measure_sql_compilations(sd), 0):
await scon.execute('select 1')

# compiler call, because config was changed
await scon.execute('SET apply_access_policies_pg to 1')
with self.assertChange(measure_sql_compilations(sd), 1):
await scon.execute('select 1')
finally:
await scon.close()

# Now restart the server to test the cache persistence.
async with tb.start_edgedb_server(
data_dir=temp_dir,
Expand Down
22 changes: 11 additions & 11 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ async def test_sql_query_33a(self):
# system columns when access policies are applied
tran = self.scon.transaction()
await tran.start()
await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')
await self.scon.execute('SET LOCAL apply_access_policies_pg TO true')
await self.scon.execute(
"""SET LOCAL "global default::filter_title" TO 'Halo 3'"""
)
Expand Down Expand Up @@ -1353,36 +1353,36 @@ async def test_sql_query_set_03(self):
self.assertEqual(res, [["public"]])

async def test_sql_query_set_04(self):
# database settings allow_user_specified_ids & apply_access_policies_sql
# database settings allow_user_specified_ids & apply_access_policies_pg
# should be unified over EdgeQL and SQL adapter

async def set_current_database(val: Optional[bool]):
if val is None:
await self.con.execute(
f'''
configure current database
reset apply_access_policies_sql;
reset apply_access_policies_pg;
'''
)
else:
await self.con.execute(
f'''
configure current database
set apply_access_policies_sql := {str(val).lower()};
set apply_access_policies_pg := {str(val).lower()};
'''
)

async def set_sql(val: Optional[bool]):
if val is None:
await self.scon.execute(
f'''
RESET apply_access_policies_sql;
RESET apply_access_policies_pg;
'''
)
else:
await self.scon.execute(
f'''
SET apply_access_policies_sql TO '{str(val).lower()}';
SET apply_access_policies_pg TO '{str(val).lower()}';
'''
)

Expand Down Expand Up @@ -2150,7 +2150,7 @@ async def test_sql_query_access_policy_01(self):
],
)

await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')
await self.scon.execute('SET LOCAL apply_access_policies_pg TO true')

# access policies applied
res = await self.squery_values(
Expand Down Expand Up @@ -2179,7 +2179,7 @@ async def test_sql_query_access_policy_02(self):
res = await self.squery_values('SELECT x FROM "ContentSummary"')
self.assertEqual(res, [[5]])

await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')
await self.scon.execute('SET LOCAL apply_access_policies_pg TO true')

# access policies applied
res = await self.squery_values('SELECT x FROM "ContentSummary"')
Expand All @@ -2202,7 +2202,7 @@ async def test_sql_query_access_policy_03(self):

# allowed without applying access policies

await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')
await self.scon.execute('SET LOCAL apply_access_policies_pg TO true')

# allowed when filter_title == 'summary'
await self.scon.execute(
Expand Down Expand Up @@ -2233,7 +2233,7 @@ async def test_sql_query_access_policy_04(self):
res = await self.squery_values('SELECT * FROM ONLY "Content"')
self.assertEqual(len(res), 1)

await self.scon.execute('SET LOCAL apply_access_policies_sql TO true')
await self.scon.execute('SET LOCAL apply_access_policies_pg TO true')

await self.scon.execute(
"""SET LOCAL "global default::filter_title" TO 'Halo 3'"""
Expand Down Expand Up @@ -2313,7 +2313,7 @@ async def test_sql_query_locking_01(self):
"locking clause not supported",
):
await self.scon.execute(
'SET LOCAL apply_access_policies_sql TO TRUE'
'SET LOCAL apply_access_policies_pg TO TRUE'
)
await self.squery_values(
'''
Expand Down
Loading