diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py index fdeb7e2487c..f788072ea29 100644 --- a/edb/pgsql/resolver/command.py +++ b/edb/pgsql/resolver/command.py @@ -1418,7 +1418,7 @@ def _compile_uncompiled_dml( make_globals_empty=True, # TODO: globals in SQL singletons=singletons, anchors=anchors, - allow_user_specified_id=True, # TODO: should this be enabled? + allow_user_specified_id=ctx.options.allow_user_specified_id, ) ir_stmt = qlcompiler.compile_ast_to_ir( ql_stmt, diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py index bcbad03398c..f5c8d8bd293 100644 --- a/edb/pgsql/resolver/context.py +++ b/edb/pgsql/resolver/context.py @@ -34,7 +34,7 @@ from edb.schema import pointers as s_pointers -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True, repr=False, match_args=False) class Options: current_database: str @@ -43,7 +43,10 @@ class Options: current_query: str # schemas that will be searched when idents don't have an explicit one - search_path: Sequence[str] = ("public",) + search_path: Sequence[str] + + # allow setting id in inserts + allow_user_specified_id: bool @dataclass(kw_only=True) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index b6ce3d3346a..584e7d2f08b 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -561,18 +561,30 @@ def parse_search_path(search_path_str: str) -> list[str]: def translate_query( stmt: pgast.Base ) -> Tuple[pg_codegen.SQLSource, Optional[dbstate.CommandCompleteTag]]: - args = {} + + search_path: Sequence[str] = ("public",) + allow_user_specified_id: bool = False + + try: + sp = tx_state.get("search_path") + except KeyError: + sp = None + if isinstance(sp, str): + search_path = parse_search_path(sp) + try: - search_path = tx_state.get("search_path") + allow_id = tx_state.get("allow_user_specified_id") except KeyError: - search_path = None - if isinstance(search_path, str): - args['search_path'] = parse_search_path(search_path) + allow_id = None + if isinstance(allow_id, str): + allow_user_specified_id = bool(allow_id) + options = pg_resolver.Options( current_user=current_user, current_database=current_database, current_query=query_str, - **args + search_path=search_path, + allow_user_specified_id=allow_user_specified_id ) resolved, complete_tag = pg_resolver.resolve( stmt, schema, options @@ -602,6 +614,7 @@ def compute_stmt_name(text: str) -> str: # frontend-only settings (key) and their mutability (value) fe_settings_mutable = { 'search_path': True, + 'allow_user_specified_id': True, 'server_version': False, 'server_version_num': False, } diff --git a/tests/test_sql_dml.py b/tests/test_sql_dml.py index 98fd9a40917..c4feb7b1df8 100644 --- a/tests/test_sql_dml.py +++ b/tests/test_sql_dml.py @@ -101,6 +101,8 @@ async def test_sql_dml_insert_02(self): # when columns are not specified, all columns are expected, # in alphabetical order: # id, __type__, owner, title + + await self.scon.execute("SET LOCAL allow_user_specified_id TO TRUE") with self.assertRaisesRegex( asyncpg.DataError, "cannot assign to link '__type__': it is protected", @@ -735,6 +737,8 @@ async def test_sql_dml_insert_34(self): id4 = uuid.uuid4() id5 = uuid.uuid4() + await self.scon.execute("SET LOCAL allow_user_specified_id TO TRUE") + res = await self.squery_values( f''' INSERT INTO "Document" (id) @@ -765,6 +769,30 @@ async def test_sql_dml_insert_34(self): ) self.assertEqual(res, [[id5]]) + async def test_sql_dml_insert_35(self): + with self.assertRaisesRegex( + asyncpg.exceptions.DataError, + "cannot assign to property 'id'", + ): + res = await self.squery_values( + f''' + INSERT INTO "Document" (id) VALUES ($1) RETURNING id + ''', + uuid.uuid4(), + ) + + await self.scon.execute( + 'SET LOCAL allow_user_specified_id TO TRUE' + ) + id = uuid.uuid4() + res = await self.squery_values( + f''' + INSERT INTO "Document" (id) VALUES ($1) RETURNING id + ''', + id, + ) + self.assertEqual(res, [[id]]) + async def test_sql_dml_delete_01(self): # delete, inspect CommandComplete tag