Skip to content

Commit

Permalink
SET allow_user_specified_id over SQL adapter (#7709)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Sep 5, 2024
1 parent 2703faa commit 9def2e6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 9 deletions.
2 changes: 1 addition & 1 deletion edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def _compile_uncompiled_dml(
make_globals_empty=True, # TODO: globals in SQL
singletons=singletons,
anchors=anchors,
allow_user_specified_id=True, # TODO: should this be enabled?
allow_user_specified_id=ctx.options.allow_user_specified_id,
)
ir_stmt = qlcompiler.compile_ast_to_ir(
ql_stmt,
Expand Down
7 changes: 5 additions & 2 deletions edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from edb.schema import pointers as s_pointers


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True, repr=False, match_args=False)
class Options:
current_database: str

Expand All @@ -43,7 +43,10 @@ class Options:
current_query: str

# schemas that will be searched when idents don't have an explicit one
search_path: Sequence[str] = ("public",)
search_path: Sequence[str]

# allow setting id in inserts
allow_user_specified_id: bool


@dataclass(kw_only=True)
Expand Down
25 changes: 19 additions & 6 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,18 +561,30 @@ def parse_search_path(search_path_str: str) -> list[str]:
def translate_query(
stmt: pgast.Base
) -> Tuple[pg_codegen.SQLSource, Optional[dbstate.CommandCompleteTag]]:
args = {}

search_path: Sequence[str] = ("public",)
allow_user_specified_id: bool = False

try:
sp = tx_state.get("search_path")
except KeyError:
sp = None
if isinstance(sp, str):
search_path = parse_search_path(sp)

try:
search_path = tx_state.get("search_path")
allow_id = tx_state.get("allow_user_specified_id")
except KeyError:
search_path = None
if isinstance(search_path, str):
args['search_path'] = parse_search_path(search_path)
allow_id = None
if isinstance(allow_id, str):
allow_user_specified_id = bool(allow_id)

options = pg_resolver.Options(
current_user=current_user,
current_database=current_database,
current_query=query_str,
**args
search_path=search_path,
allow_user_specified_id=allow_user_specified_id
)
resolved, complete_tag = pg_resolver.resolve(
stmt, schema, options
Expand Down Expand Up @@ -602,6 +614,7 @@ def compute_stmt_name(text: str) -> str:
# frontend-only settings (key) and their mutability (value)
fe_settings_mutable = {
'search_path': True,
'allow_user_specified_id': True,
'server_version': False,
'server_version_num': False,
}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_sql_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ async def test_sql_dml_insert_02(self):
# when columns are not specified, all columns are expected,
# in alphabetical order:
# id, __type__, owner, title

await self.scon.execute("SET LOCAL allow_user_specified_id TO TRUE")
with self.assertRaisesRegex(
asyncpg.DataError,
"cannot assign to link '__type__': it is protected",
Expand Down Expand Up @@ -735,6 +737,8 @@ async def test_sql_dml_insert_34(self):
id4 = uuid.uuid4()
id5 = uuid.uuid4()

await self.scon.execute("SET LOCAL allow_user_specified_id TO TRUE")

res = await self.squery_values(
f'''
INSERT INTO "Document" (id)
Expand Down Expand Up @@ -765,6 +769,30 @@ async def test_sql_dml_insert_34(self):
)
self.assertEqual(res, [[id5]])

async def test_sql_dml_insert_35(self):
with self.assertRaisesRegex(
asyncpg.exceptions.DataError,
"cannot assign to property 'id'",
):
res = await self.squery_values(
f'''
INSERT INTO "Document" (id) VALUES ($1) RETURNING id
''',
uuid.uuid4(),
)

await self.scon.execute(
'SET LOCAL allow_user_specified_id TO TRUE'
)
id = uuid.uuid4()
res = await self.squery_values(
f'''
INSERT INTO "Document" (id) VALUES ($1) RETURNING id
''',
id,
)
self.assertEqual(res, [[id]])

async def test_sql_dml_delete_01(self):
# delete, inspect CommandComplete tag

Expand Down

0 comments on commit 9def2e6

Please sign in to comment.