diff --git a/edb/schema/schema.py b/edb/schema/schema.py index e83e9eefc99..f8696c4f82b 100644 --- a/edb/schema/schema.py +++ b/edb/schema/schema.py @@ -1131,17 +1131,35 @@ def _search_with_getter( if not local and ( orig_module is None or ( - not alias_hit and module and not ( - self.has_module(fmod := module.split('::')[0]) - or (disallow_module and disallow_module(fmod)) - ) + not alias_hit and module ) ): - mod_name = 'std' if orig_module is None else f'std::{orig_module}' - fqname = sn.QualName(mod_name, shortname) - result = getter(self, fqname) - if result is not None: - return result + # If no module was specified, look in std + if orig_module is None: + mod_name = 'std' + fqname = sn.QualName(mod_name, shortname) + result = getter(self, fqname) + if result is not None: + return result + + # If a module was specified in the name, ensure that no base module + # of the same name exists. + # + # If no module was specified, try the default module name as a part + # of std. The same condition applies. + if module and not ( + self.has_module(fmod := module.split('::')[0]) + or (disallow_module and disallow_module(fmod)) + ): + mod_name = ( + f'std::{module}' + if orig_module is None + else f'std::{orig_module}' + ) + fqname = sn.QualName(mod_name, shortname) + result = getter(self, fqname) + if result is not None: + return result return default diff --git a/tests/test_edgeql_expressions.py b/tests/test_edgeql_expressions.py index 97ec599a1cb..ad511a51b84 100644 --- a/tests/test_edgeql_expressions.py +++ b/tests/test_edgeql_expressions.py @@ -9905,3 +9905,196 @@ async def test_edgeql_cast_to_function_01(self): await self.con.execute(f""" select 1; """) + + async def test_edgeql_expr_with_module_01(self): + await self.con.execute(f""" + create module dummy; + """) + + valid_queries = [ + 'SELECT {} = 1', + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + + for query in valid_queries: + await self.con.execute(query) + + async def test_edgeql_expr_with_module_02(self): + await self.con.execute(f""" + create module dummy; + create type default::int64; + """) + + valid_queries = [ + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + invalid_queries = [ + 'SELECT {} = 1', + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidTypeError, + "operator '=' cannot be applied", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_03(self): + await self.con.execute(f""" + create module dummy; + """) + + valid_queries = [ + 'select _test::abs(1)', + 'select std::_test::abs(1)', + 'with module dummy select _test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select abs(1)', + 'with module _test select _test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select _test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select _test::abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + invalid_queries = [ + 'select abs(1)', + 'with module dummy select abs(1)', + 'with module std select abs(1)', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "abs' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_04(self): + await self.con.execute(f""" + create module dummy; + create module _test; + """) + + valid_queries = [ + 'select std::_test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + invalid_queries = [ + 'select abs(1)', + 'select _test::abs(1)', + 'with module dummy select abs(1)', + 'with module dummy select _test::abs(1)', + 'with module _test select abs(1)', + 'with module _test select _test::abs(1)', + 'with module std select abs(1)', + 'with module std select _test::abs(1)', + 'with module std::_test select _test::abs(1)', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "abs' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_05(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create scalar type std::test::Foo extending int64; + """) + + valid_queries = [ + 'select 1', + 'select 1', + 'with module dummy select 1', + 'with module dummy select 1', + 'with module test select 1', + 'with module test select 1', + 'with module test select 1', + 'with module std select 1', + 'with module std select 1', + 'with module std::test select 1', + 'with module std::test select 1', + 'with module std::test select 1', + ] + invalid_queries = [ + 'select 1', + 'with module dummy select 1', + 'with module std select 1', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_06(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create scalar type std::test::Foo extending int64; + create module test; + """) + + valid_queries = [ + 'select 1', + 'with module dummy select 1', + 'with module test select 1', + 'with module std select 1', + 'with module std::test select 1', + 'with module std::test select 1', + ] + invalid_queries = [ + 'select 1', + 'select 1', + 'with module dummy select 1', + 'with module dummy select 1', + 'with module test select 1', + 'with module test select 1', + 'with module std select 1', + 'with module std select 1', + 'with module std::test select 1', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query) diff --git a/tests/test_schema.py b/tests/test_schema.py index 8300a1927b5..a3e08a890c8 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Type, TYPE_CHECKING import random import re @@ -3083,6 +3083,146 @@ def test_schema_unknown_typename_05(self): } """ + def _run_migration_to(self, schema_text: str) -> None: + migration_text = f''' + START MIGRATION TO {{ + {schema_text} + }}; + POPULATE MIGRATION; + COMMIT MIGRATION; + ''' + + self.run_ddl(self.schema, migration_text) + + def _check_valid_queries( + self, + schema_text: str, + valid_queries: list[str], + ) -> None: + + for query in valid_queries: + query_text = f''' + module default {{ alias query := ({query}); }} + ''' + self._run_migration_to(schema_text + query_text) + + def _check_invalid_queries( + self, + schema_text: str, + invalid_queries: list[str], + error_type: Type, + error_message: str, + ) -> None: + for query in invalid_queries: + query_text = f''' + module default {{ alias query := ({query}); }} + ''' + with self.assertRaisesRegex(error_type, error_message): + self._run_migration_to(schema_text + query_text) + + def test_schema_with_module_01(self): + schema_text = f''' + module dummy {{}} + ''' + valid_queries = [ + 'SELECT {} = 1', + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + self._check_valid_queries(schema_text, valid_queries) + + def test_schema_with_module_02(self): + schema_text = f''' + module dummy {{}} + module default {{ type int64; }} + ''' + valid_queries = [ + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + invalid_queries = [ + 'SELECT {} = 1', + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + self._check_valid_queries(schema_text, valid_queries) + self._check_invalid_queries( + schema_text, + invalid_queries, + errors.InvalidTypeError, + "operator '=' cannot be applied", + ) + + def test_schema_with_module_03(self): + schema_text = f''' + module dummy {{}} + ''' + valid_queries = [ + 'select _test::abs(1)', + 'select std::_test::abs(1)', + 'with module dummy select _test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select abs(1)', + 'with module _test select _test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select _test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select _test::abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + invalid_queries = [ + 'select abs(1)', + 'with module dummy select abs(1)', + 'with module std select abs(1)', + ] + self._check_valid_queries(schema_text, valid_queries) + self._check_invalid_queries( + schema_text, + invalid_queries, + errors.InvalidReferenceError, + "abs' does not exist", + ) + + def test_schema_with_module_04(self): + schema_text = f''' + module dummy {{}} + module _test {{}} + ''' + valid_queries = [ + 'select std::_test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + invalid_queries = [ + 'select abs(1)', + 'select _test::abs(1)', + 'with module dummy select abs(1)', + 'with module dummy select _test::abs(1)', + 'with module _test select abs(1)', + 'with module _test select _test::abs(1)', + 'with module std select abs(1)', + 'with module std select _test::abs(1)', + 'with module std::_test select _test::abs(1)', + ] + self._check_valid_queries(schema_text, valid_queries) + self._check_invalid_queries( + schema_text, + invalid_queries, + errors.InvalidReferenceError, + "abs' does not exist", + ) + class TestGetMigration(tb.BaseSchemaLoadTest): """Test migration deparse consistency. @@ -11104,6 +11244,149 @@ def test_schema_describe_overload_01(self): """, ) + def test_schema_describe_with_module_01(self): + schema_text = f''' + module dummy {{}} + ''' + valid_queries = [ + 'SELECT {} = 1', + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + normalized = 'SELECT ({} = 1)' + for query in valid_queries: + self._assert_describe( + schema_text + f''' + module default {{ alias query := ({query}); }} + ''', + + 'describe module default as sdl', + + f''' + alias default::query := ({normalized}); + ''', + explicit_modules=True, + ) + + def test_schema_describe_with_module_02a(self): + schema_text = f''' + module dummy {{}} + module default {{ type int64; }} + ''' + valid_queries = [ + 'SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE dummy SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + 'WITH MODULE std SELECT {} = 1', + ] + normalized = 'SELECT ({} = 1)' + for query in valid_queries: + self._assert_describe( + schema_text + f''' + module default {{ alias query := ({query}); }} + ''', + + 'describe module default as sdl', + + f''' + alias default::query := ({normalized}); + type default::int64; + ''', + explicit_modules=True, + ) + + def test_schema_describe_with_module_02b(self): + schema_text = f''' + module dummy {{}} + module default {{ type int64; }} + ''' + valid_queries = [ + 'SELECT {}', + 'SELECT {}', + 'WITH MODULE dummy SELECT {}', + 'WITH MODULE std SELECT {}', + ] + normalized = 'SELECT {}' + for query in valid_queries: + self._assert_describe( + schema_text + f''' + module default {{ alias query := ({query}); }} + ''', + + 'describe module default as sdl', + + f''' + alias default::query := ({normalized}); + type default::int64; + ''', + explicit_modules=True, + ) + + def test_schema_describe_with_module_03(self): + schema_text = f''' + module dummy {{}} + ''' + valid_queries = [ + 'select _test::abs(1)', + 'select std::_test::abs(1)', + 'with module dummy select _test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select abs(1)', + 'with module _test select _test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select _test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select _test::abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + normalized = 'SELECT std::_test::abs(1)' + for query in valid_queries: + self._assert_describe( + schema_text + f''' + module default {{ alias query := ({query}); }} + ''', + + 'describe module default as sdl', + + f''' + alias default::query := ({normalized}); + ''', + explicit_modules=True, + ) + + def test_schema_describe_with_module_04(self): + schema_text = f''' + module dummy {{}} + module _test {{}} + ''' + valid_queries = [ + 'select std::_test::abs(1)', + 'with module dummy select std::_test::abs(1)', + 'with module _test select std::_test::abs(1)', + 'with module std select std::_test::abs(1)', + 'with module std::_test select abs(1)', + 'with module std::_test select std::_test::abs(1)', + ] + normalized = 'SELECT std::_test::abs(1)' + for query in valid_queries: + self._assert_describe( + schema_text + f''' + module default {{ alias query := ({query}); }} + ''', + + 'describe module default as sdl', + + f''' + alias default::query := ({normalized}); + ''', + explicit_modules=True, + ) + class TestSDLTextFromSchema(BaseDescribeTest):