Skip to content

Commit

Permalink
Search std for module name when using with module.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Sep 16, 2024
1 parent 6ba3152 commit 15bedf9
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 9 deletions.
36 changes: 27 additions & 9 deletions edb/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,17 +1130,35 @@ def _search_with_getter(
if not local and (
orig_module is None
or (
not alias_hit and module and not (
self.has_module(fmod := module.split('::')[0])
or (disallow_module and disallow_module(fmod))
)
not alias_hit and module
)
):
mod_name = 'std' if orig_module is None else f'std::{orig_module}'
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result
# If no module was specified, look in std
if orig_module is None:
mod_name = 'std'
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result

# If a module was specified in the name, ensure that no base module
# of the same name exists.
#
# If no module was specified, try the default module name as a part
# of std. The same condition applies.
if module and not (
self.has_module(fmod := module.split('::')[0])
or (disallow_module and disallow_module(fmod))
):
mod_name = (
f'std::{module}'
if orig_module is None
else f'std::{orig_module}'
)
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result

return default

Expand Down
197 changes: 197 additions & 0 deletions tests/test_edgeql_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9905,3 +9905,200 @@ async def test_edgeql_cast_to_function_01(self):
await self.con.execute(f"""
select <cal::to_local_date>1;
""")

async def test_edgeql_expr_with_module_01(self):
await self.con.execute(f"""
create module dummy;
""")

valid_queries = [
'SELECT <int64>{} = 1;',
'SELECT <std::int64>{} = 1;',
'WITH MODULE dummy SELECT <int64>{} = 1;',
'WITH MODULE dummy SELECT <std::int64>{} = 1;',
'WITH MODULE std SELECT <int64>{} = 1;',
'WITH MODULE std SELECT <std::int64>{} = 1;',
]

for query in valid_queries:
await self.con.execute(query)

async def test_edgeql_expr_with_module_02(self):
await self.con.execute(f"""
create module dummy;
create type default::int64;
""")

valid_queries = [
'SELECT <std::int64>{} = 1;',
'WITH MODULE dummy SELECT <int64>{} = 1;',
'WITH MODULE dummy SELECT <std::int64>{} = 1;',
'WITH MODULE std SELECT <int64>{} = 1;',
'WITH MODULE std SELECT <std::int64>{} = 1;',
]
invalid_queries = [
'SELECT <int64>{} = 1;',
'SELECT <default::int64>{} = 1;',
'WITH MODULE dummy SELECT <default::int64>{} = 1;',
'WITH MODULE std SELECT <default::int64>{} = 1;',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidTypeError,
"operator '=' cannot be applied",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_03(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create scalar type std::test::Foo extending int64;
""")

valid_queries = [
'select <test::Foo>1;',
'select <std::test::Foo>1;',
'with module dummy select <test::Foo>1;',
'with module dummy select <std::test::Foo>1;',
'with module test select <Foo>1;',
'with module test select <test::Foo>1;',
'with module test select <std::test::Foo>1;',
'with module std select <test::Foo>1;',
'with module std select <std::test::Foo>1;',
'with module std::test select <Foo>1;',
'with module std::test select <test::Foo>1;',
'with module std::test select <std::test::Foo>1;',
]
invalid_queries = [
'select <Foo>1;',
'with module dummy select <Foo>1;',
'with module std select <Foo>1;',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_04(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create scalar type std::test::Foo extending int64;
create module test;
""")

valid_queries = [
'select <std::test::Foo>1;',
'with module dummy select <std::test::Foo>1;',
'with module test select <std::test::Foo>1;',
'with module std select <std::test::Foo>1;',
'with module std::test select <Foo>1;',
'with module std::test select <std::test::Foo>1;',
]
invalid_queries = [
'select <Foo>1;',
'select <test::Foo>1;',
'with module dummy select <Foo>1;',
'with module dummy select <test::Foo>1;',
'with module test select <Foo>1;',
'with module test select <test::Foo>1;',
'with module std select <Foo>1;',
'with module std select <test::Foo>1;',
'with module std::test select <test::Foo>1;',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_05(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create function std::test::Foo(x: int64) -> int64 using(x);
""")

valid_queries = [
'select test::Foo(1);',
'select std::test::Foo(1);',
'with module dummy select test::Foo(1);',
'with module dummy select std::test::Foo(1);',
'with module test select Foo(1);',
'with module test select test::Foo(1);',
'with module test select std::test::Foo(1);',
'with module std select test::Foo(1);',
'with module std select std::test::Foo(1);',
'with module std::test select Foo(1);',
'with module std::test select test::Foo(1);',
'with module std::test select std::test::Foo(1);',
]
invalid_queries = [
'select Foo(1);',
'with module dummy select Foo(1);',
'with module std select Foo(1);',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_06(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create function std::test::Foo(x: int64) -> int64 using(x);
create module test;
""")

valid_queries = [
'select std::test::Foo(1);',
'with module dummy select std::test::Foo(1);',
'with module test select std::test::Foo(1);',
'with module std select std::test::Foo(1);',
'with module std::test select Foo(1);',
'with module std::test select std::test::Foo(1);',
]
invalid_queries = [
'select Foo(1);',
'select test::Foo(1);',
'with module dummy select Foo(1);',
'with module dummy select test::Foo(1);',
'with module test select Foo(1);',
'with module test select test::Foo(1);',
'with module std select Foo(1);',
'with module std select test::Foo(1);',
'with module std::test select test::Foo(1);',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)

0 comments on commit 15bedf9

Please sign in to comment.