Skip to content

Commit

Permalink
Require extension modules to live in ext:: (#7526)
Browse files Browse the repository at this point in the history
We assume this in some places but one of our tests violated it.
  • Loading branch information
msullivan authored Jul 3, 2024
1 parent ce9112c commit 6acf5d5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
11 changes: 11 additions & 0 deletions edb/schema/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,17 @@ def _create_begin(

if not context.canonical:
package = self.scls.get_package(schema)

module = package.get_ext_module(schema)
if module:
module_name = sn.UnqualName(module)
if module_name.get_root_module_name() != s_schema.EXT_MODULE:
raise errors.SchemaError(
f'built-in extension {self.classname} has invalid '
f'module "{module}": '
f'extension modules must begin with "ext::"'
)

script = package.get_script(schema)
if script:
block, _ = qlparser.parse_extension_package_body_block(script)
Expand Down
9 changes: 6 additions & 3 deletions edb/schema/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def from_string(cls: Type[NameT], name: str) -> NameT:
def get_local_name(self) -> UnqualName:
...

def get_root_module_name(self) -> UnqualName:
...

def __lt__(self, other: Any) -> bool:
...

Expand Down Expand Up @@ -92,9 +95,6 @@ def get_local_name(self) -> UnqualName:
def get_module_name(self) -> Name:
...

def get_root_module_name(self) -> UnqualName:
...

class UnqualName(Name):

__slots__ = ('name',)
Expand Down Expand Up @@ -173,6 +173,9 @@ def from_string(
def get_local_name(self) -> UnqualName:
return self

def get_root_module_name(self) -> UnqualName:
return UnqualName(self.name.partition('::')[0])

def __str__(self) -> str:
return self.name

Expand Down
45 changes: 24 additions & 21 deletions tests/test_edgeql_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,48 +35,50 @@ async def _extension_test_01(self):

await self.assert_query_result(
'''
select ltree::nlevel(
<ltree::ltree><json><ltree::ltree>'foo.bar');
select ext::ltree::nlevel(
<ext::ltree::ltree><json><ext::ltree::ltree>'foo.bar');
''',
[2],
)
await self.assert_query_result(
'''
select ltree::asdf(
<ltree::ltree><json><ltree::ltree>'foo.bar');
select ext::ltree::asdf(
<ext::ltree::ltree><json><ext::ltree::ltree>'foo.bar');
''',
[3],
)
await self.assert_query_result(
'''
select <str>(
<ltree::ltree><json><ltree::ltree>'foo.bar');
<ext::ltree::ltree><json><ext::ltree::ltree>'foo.bar');
''',
['foo.bar'],
)
await self.assert_query_result(
'''
select <ltree::ltree>'foo.bar' = <ltree::ltree>'foo.baz';
select <ext::ltree::ltree>'foo.bar'
= <ext::ltree::ltree>'foo.baz';
''',
[False],
)
await self.assert_query_result(
'''
select <ltree::ltree>'foo.bar' != <ltree::ltree>'foo.baz';
select <ext::ltree::ltree>'foo.bar'
!= <ext::ltree::ltree>'foo.baz';
''',
[True],
)
await self.assert_query_result(
'''
select <ltree::ltree><json><ltree::ltree>'foo.bar';
select <ext::ltree::ltree><json><ext::ltree::ltree>'foo.bar';
''',
[['foo', 'bar']],
json_only=True,
)

await self.con.execute('''
create type Foo { create property x -> ltree::ltree };
insert Foo { x := <ltree::ltree>'foo.bar.baz' };
create type Foo { create property x -> ext::ltree::ltree };
insert Foo { x := <ext::ltree::ltree>'foo.bar.baz' };
''')

await self.assert_query_result(
Expand All @@ -96,7 +98,7 @@ async def test_edgeql_extensions_01(self):
# Make an extension that wraps a tiny bit of the ltree package.
await self.con.execute('''
create extension package ltree VERSION '1.0' {
set ext_module := "ltree";
set ext_module := "ext::ltree";
set sql_extensions := ["ltree >=1.0,<10.0"];
set sql_setup_script := $$
Expand All @@ -113,27 +115,27 @@ async def test_edgeql_extensions_01(self):
DROP FUNCTION edgedb.asdf(edgedb.ltree);
$$;
create module ltree;
create scalar type ltree::ltree extending anyscalar {
create module ext::ltree;
create scalar type ext::ltree::ltree extending anyscalar {
set sql_type := "ltree";
};
create cast from ltree::ltree to std::str {
create cast from ext::ltree::ltree to std::str {
SET volatility := 'Immutable';
USING SQL CAST;
};
create cast from std::str to ltree::ltree {
create cast from std::str to ext::ltree::ltree {
SET volatility := 'Immutable';
USING SQL CAST;
};
# Use a non-trivial json representation just to show that we can.
create cast from ltree::ltree to std::json {
create cast from ext::ltree::ltree to std::json {
SET volatility := 'Immutable';
USING SQL $$
select to_jsonb(string_to_array("val"::text, '.'));
$$
};
create cast from std::json to ltree::ltree {
create cast from std::json to ext::ltree::ltree {
SET volatility := 'Immutable';
USING SQL $$
select string_agg(edgedb.raise_on_null(
Expand All @@ -145,10 +147,11 @@ async def test_edgeql_extensions_01(self):
as z(z);
$$
};
create function ltree::nlevel(v: ltree::ltree) -> std::int32 {
create function ext::ltree::nlevel(
v: ext::ltree::ltree) -> std::int32 {
using sql function 'edgedb.nlevel';
};
create function ltree::asdf(v: ltree::ltree) -> std::int32 {
create function ext::ltree::asdf(v: ext::ltree::ltree) -> std::int32 {
using sql function 'edgedb.asdf';
};
};
Expand Down Expand Up @@ -366,7 +369,7 @@ async def test_edgeql_extensions_02(self):
async def test_edgeql_extensions_03(self):
await self.con.execute('''
create extension package ltree_broken VERSION '1.0' {
set ext_module := "ltree";
set ext_module := "ext::ltree";
set sql_extensions := ["ltree >=1000.0"];
create module ltree;
};
Expand All @@ -387,7 +390,7 @@ async def test_edgeql_extensions_03(self):
async def test_edgeql_extensions_04(self):
await self.con.execute('''
create extension package ltree_broken VERSION '1.0' {
set ext_module := "ltree";
set ext_module := "ext::ltree";
set sql_extensions := ["loltree >=1.0"];
create module ltree;
};
Expand Down

0 comments on commit 6acf5d5

Please sign in to comment.