From 15bedf9a73895b28d906f619ecf2c7c0c6a6acb4 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Mon, 16 Sep 2024 11:54:33 -0400 Subject: [PATCH] Search std for module name when using `with module`. --- edb/schema/schema.py | 36 ++++-- tests/test_edgeql_expressions.py | 197 +++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 9 deletions(-) diff --git a/edb/schema/schema.py b/edb/schema/schema.py index 771b21ffa854..c1c8b8fe02c4 100644 --- a/edb/schema/schema.py +++ b/edb/schema/schema.py @@ -1130,17 +1130,35 @@ def _search_with_getter( if not local and ( orig_module is None or ( - not alias_hit and module and not ( - self.has_module(fmod := module.split('::')[0]) - or (disallow_module and disallow_module(fmod)) - ) + not alias_hit and module ) ): - mod_name = 'std' if orig_module is None else f'std::{orig_module}' - fqname = sn.QualName(mod_name, shortname) - result = getter(self, fqname) - if result is not None: - return result + # If no module was specified, look in std + if orig_module is None: + mod_name = 'std' + fqname = sn.QualName(mod_name, shortname) + result = getter(self, fqname) + if result is not None: + return result + + # If a module was specified in the name, ensure that no base module + # of the same name exists. + # + # If no module was specified, try the default module name as a part + # of std. The same condition applies. + if module and not ( + self.has_module(fmod := module.split('::')[0]) + or (disallow_module and disallow_module(fmod)) + ): + mod_name = ( + f'std::{module}' + if orig_module is None + else f'std::{orig_module}' + ) + fqname = sn.QualName(mod_name, shortname) + result = getter(self, fqname) + if result is not None: + return result return default diff --git a/tests/test_edgeql_expressions.py b/tests/test_edgeql_expressions.py index 97ec599a1cb4..0804020c1d1e 100644 --- a/tests/test_edgeql_expressions.py +++ b/tests/test_edgeql_expressions.py @@ -9905,3 +9905,200 @@ async def test_edgeql_cast_to_function_01(self): await self.con.execute(f""" select 1; """) + + async def test_edgeql_expr_with_module_01(self): + await self.con.execute(f""" + create module dummy; + """) + + valid_queries = [ + 'SELECT {} = 1;', + 'SELECT {} = 1;', + 'WITH MODULE dummy SELECT {} = 1;', + 'WITH MODULE dummy SELECT {} = 1;', + 'WITH MODULE std SELECT {} = 1;', + 'WITH MODULE std SELECT {} = 1;', + ] + + for query in valid_queries: + await self.con.execute(query) + + async def test_edgeql_expr_with_module_02(self): + await self.con.execute(f""" + create module dummy; + create type default::int64; + """) + + valid_queries = [ + 'SELECT {} = 1;', + 'WITH MODULE dummy SELECT {} = 1;', + 'WITH MODULE dummy SELECT {} = 1;', + 'WITH MODULE std SELECT {} = 1;', + 'WITH MODULE std SELECT {} = 1;', + ] + invalid_queries = [ + 'SELECT {} = 1;', + 'SELECT {} = 1;', + 'WITH MODULE dummy SELECT {} = 1;', + 'WITH MODULE std SELECT {} = 1;', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidTypeError, + "operator '=' cannot be applied", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_03(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create scalar type std::test::Foo extending int64; + """) + + valid_queries = [ + 'select 1;', + 'select 1;', + 'with module dummy select 1;', + 'with module dummy select 1;', + 'with module test select 1;', + 'with module test select 1;', + 'with module test select 1;', + 'with module std select 1;', + 'with module std select 1;', + 'with module std::test select 1;', + 'with module std::test select 1;', + 'with module std::test select 1;', + ] + invalid_queries = [ + 'select 1;', + 'with module dummy select 1;', + 'with module std select 1;', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_04(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create scalar type std::test::Foo extending int64; + create module test; + """) + + valid_queries = [ + 'select 1;', + 'with module dummy select 1;', + 'with module test select 1;', + 'with module std select 1;', + 'with module std::test select 1;', + 'with module std::test select 1;', + ] + invalid_queries = [ + 'select 1;', + 'select 1;', + 'with module dummy select 1;', + 'with module dummy select 1;', + 'with module test select 1;', + 'with module test select 1;', + 'with module std select 1;', + 'with module std select 1;', + 'with module std::test select 1;', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_05(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create function std::test::Foo(x: int64) -> int64 using(x); + """) + + valid_queries = [ + 'select test::Foo(1);', + 'select std::test::Foo(1);', + 'with module dummy select test::Foo(1);', + 'with module dummy select std::test::Foo(1);', + 'with module test select Foo(1);', + 'with module test select test::Foo(1);', + 'with module test select std::test::Foo(1);', + 'with module std select test::Foo(1);', + 'with module std select std::test::Foo(1);', + 'with module std::test select Foo(1);', + 'with module std::test select test::Foo(1);', + 'with module std::test select std::test::Foo(1);', + ] + invalid_queries = [ + 'select Foo(1);', + 'with module dummy select Foo(1);', + 'with module std select Foo(1);', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query) + + async def test_edgeql_expr_with_module_06(self): + await self.con.execute(f""" + create module dummy; + create module std::test; + create function std::test::Foo(x: int64) -> int64 using(x); + create module test; + """) + + valid_queries = [ + 'select std::test::Foo(1);', + 'with module dummy select std::test::Foo(1);', + 'with module test select std::test::Foo(1);', + 'with module std select std::test::Foo(1);', + 'with module std::test select Foo(1);', + 'with module std::test select std::test::Foo(1);', + ] + invalid_queries = [ + 'select Foo(1);', + 'select test::Foo(1);', + 'with module dummy select Foo(1);', + 'with module dummy select test::Foo(1);', + 'with module test select Foo(1);', + 'with module test select test::Foo(1);', + 'with module std select Foo(1);', + 'with module std select test::Foo(1);', + 'with module std::test select test::Foo(1);', + ] + + for query in valid_queries: + await self.con.execute(query) + + for query in invalid_queries: + async with self.assertRaisesRegexTx( + edgedb.errors.InvalidReferenceError, + "Foo' does not exist", + ): + await self.con.execute(query)