Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search std for module name when using with module. #7753

Merged
merged 3 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions edb/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,17 +1131,35 @@ def _search_with_getter(
if not local and (
orig_module is None
or (
not alias_hit and module and not (
self.has_module(fmod := module.split('::')[0])
or (disallow_module and disallow_module(fmod))
)
not alias_hit and module
)
):
mod_name = 'std' if orig_module is None else f'std::{orig_module}'
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result
# If no module was specified, look in std
if orig_module is None:
mod_name = 'std'
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result

# If a module was specified in the name, ensure that no base module
# of the same name exists.
#
# If no module was specified, try the default module name as a part
# of std. The same condition applies.
if module and not (
self.has_module(fmod := module.split('::')[0])
or (disallow_module and disallow_module(fmod))
):
mod_name = (
f'std::{module}'
if orig_module is None
else f'std::{orig_module}'
)
fqname = sn.QualName(mod_name, shortname)
result = getter(self, fqname)
if result is not None:
return result

return default

Expand Down
193 changes: 193 additions & 0 deletions tests/test_edgeql_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9905,3 +9905,196 @@ async def test_edgeql_cast_to_function_01(self):
await self.con.execute(f"""
select <cal::to_local_date>1;
""")

async def test_edgeql_expr_with_module_01(self):
await self.con.execute(f"""
create module dummy;
""")

valid_queries = [
'SELECT <int64>{} = 1',
'SELECT <std::int64>{} = 1',
'WITH MODULE dummy SELECT <int64>{} = 1',
'WITH MODULE dummy SELECT <std::int64>{} = 1',
'WITH MODULE std SELECT <int64>{} = 1',
'WITH MODULE std SELECT <std::int64>{} = 1',
]

for query in valid_queries:
await self.con.execute(query)

async def test_edgeql_expr_with_module_02(self):
await self.con.execute(f"""
create module dummy;
create type default::int64;
""")

valid_queries = [
'SELECT <std::int64>{} = 1',
'WITH MODULE dummy SELECT <int64>{} = 1',
'WITH MODULE dummy SELECT <std::int64>{} = 1',
'WITH MODULE std SELECT <int64>{} = 1',
'WITH MODULE std SELECT <std::int64>{} = 1',
]
invalid_queries = [
'SELECT <int64>{} = 1',
'SELECT <default::int64>{} = 1',
'WITH MODULE dummy SELECT <default::int64>{} = 1',
'WITH MODULE std SELECT <default::int64>{} = 1',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidTypeError,
"operator '=' cannot be applied",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_03(self):
await self.con.execute(f"""
create module dummy;
""")

valid_queries = [
'select _test::abs(1)',
'select std::_test::abs(1)',
'with module dummy select _test::abs(1)',
'with module dummy select std::_test::abs(1)',
'with module _test select abs(1)',
'with module _test select _test::abs(1)',
'with module _test select std::_test::abs(1)',
'with module std select _test::abs(1)',
'with module std select std::_test::abs(1)',
'with module std::_test select abs(1)',
'with module std::_test select _test::abs(1)',
'with module std::_test select std::_test::abs(1)',
]
invalid_queries = [
'select abs(1)',
'with module dummy select abs(1)',
'with module std select abs(1)',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"abs' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_04(self):
await self.con.execute(f"""
create module dummy;
create module _test;
""")

valid_queries = [
'select std::_test::abs(1)',
'with module dummy select std::_test::abs(1)',
'with module _test select std::_test::abs(1)',
'with module std select std::_test::abs(1)',
'with module std::_test select abs(1)',
'with module std::_test select std::_test::abs(1)',
]
invalid_queries = [
'select abs(1)',
'select _test::abs(1)',
'with module dummy select abs(1)',
'with module dummy select _test::abs(1)',
'with module _test select abs(1)',
'with module _test select _test::abs(1)',
'with module std select abs(1)',
'with module std select _test::abs(1)',
'with module std::_test select _test::abs(1)',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"abs' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_05(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create scalar type std::test::Foo extending int64;
""")

valid_queries = [
'select <test::Foo>1',
'select <std::test::Foo>1',
'with module dummy select <test::Foo>1',
'with module dummy select <std::test::Foo>1',
'with module test select <Foo>1',
'with module test select <test::Foo>1',
'with module test select <std::test::Foo>1',
'with module std select <test::Foo>1',
'with module std select <std::test::Foo>1',
'with module std::test select <Foo>1',
'with module std::test select <test::Foo>1',
'with module std::test select <std::test::Foo>1',
]
invalid_queries = [
'select <Foo>1',
'with module dummy select <Foo>1',
'with module std select <Foo>1',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)

async def test_edgeql_expr_with_module_06(self):
await self.con.execute(f"""
create module dummy;
create module std::test;
create scalar type std::test::Foo extending int64;
create module test;
""")

valid_queries = [
'select <std::test::Foo>1',
'with module dummy select <std::test::Foo>1',
'with module test select <std::test::Foo>1',
'with module std select <std::test::Foo>1',
'with module std::test select <Foo>1',
'with module std::test select <std::test::Foo>1',
]
invalid_queries = [
'select <Foo>1',
'select <test::Foo>1',
'with module dummy select <Foo>1',
'with module dummy select <test::Foo>1',
'with module test select <Foo>1',
'with module test select <test::Foo>1',
'with module std select <Foo>1',
'with module std select <test::Foo>1',
'with module std::test select <test::Foo>1',
]

for query in valid_queries:
await self.con.execute(query)

for query in invalid_queries:
async with self.assertRaisesRegexTx(
edgedb.errors.InvalidReferenceError,
"Foo' does not exist",
):
await self.con.execute(query)
Loading