Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Nov 15, 2024
1 parent d55ee99 commit 6ce3c9a
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
3 changes: 2 additions & 1 deletion edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class Options:
# apply access policies to select & dml statements
apply_access_policies: bool

for_edgedb_protocol: bool
# makes sure that output does not contain duplicated column names
disambiguate_column_names: bool


@dataclass(kw_only=True)
Expand Down
9 changes: 8 additions & 1 deletion edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def _resolve_ResTarget(
rel_var_name = table.alias or table.name
if rel_var_name:
nam = rel_var_name + '_' + nam
if nam in existing_names:
if ctx.options.disambiguate_column_names:
raise errors.QueryError(
f'duplicate column name: `{nam}`',
span=res_target.span,
pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE,
)
existing_names.add(nam)

res.append(
Expand Down Expand Up @@ -145,7 +152,7 @@ def _resolve_ResTarget(

if res_target.name:
# explicit duplicate name: error out
if ctx.options.for_edgedb_protocol:
if ctx.options.disambiguate_column_names:
raise errors.QueryError(
f'duplicate column name: `{alias}`',
span=res_target.span,
Expand Down
1 change: 1 addition & 0 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def compile_sql(
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
disambiguate_column_names=False,
)

def compile_serialized_request(
Expand Down
5 changes: 4 additions & 1 deletion edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ def compile_sql(
current_user: str,
allow_user_specified_id: Optional[bool],
apply_access_policies_sql: Optional[bool],
disambiguate_column_names: bool,
) -> List[dbstate.SQLQueryUnit]:
opts = ResolverOptionsPartial(
query_str=query_str,
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
disambiguate_column_names=disambiguate_column_names,
)

stmts = pg_parser.parse(query_str, propagate_spans=True)
Expand Down Expand Up @@ -274,6 +276,7 @@ class ResolverOptionsPartial:
query_str: str
allow_user_specified_id: Optional[bool]
apply_access_policies_sql: Optional[bool]
disambiguate_column_names: bool


def resolve_query(
Expand Down Expand Up @@ -314,7 +317,7 @@ def resolve_query(
search_path=search_path,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies=apply_access_policies,
for_edgedb_protocol=True,
disambiguate_column_names=opts.disambiguate_column_names,
)
resolved = pg_resolver.resolve(stmt, schema, options)
source = pg_codegen.generate(resolved.ast, with_translation_data=True)
Expand Down
23 changes: 17 additions & 6 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,12 +872,12 @@ async def test_sql_query_44(self):
await self.squery_values('SELECT name FROM User')

async def test_sql_query_45(self):
with self.assertRaisesRegex(
asyncpg.InvalidColumnReferenceError,
'duplicate column name: `a`',
position="16",
):
await self.squery_values('SELECT 1 AS a, 2 AS a')
await self.squery_values('SELECT 1 AS a, 2 AS a')

# Over edgedb-protocol, this query should raise:
# asyncpg.InvalidColumnReferenceError
# 'duplicate column name: `a`'
# position="16"

async def test_sql_query_46(self):
res = await self.scon.fetch(
Expand Down Expand Up @@ -918,6 +918,11 @@ async def test_sql_query_48(self):
# duplicate rel var names can yield duplicate column names
self.assert_shape(res, 4, ['a', 'y_a', 'y_a'])

# Over edgedb-protocol, this query should raise:
# asyncpg.InvalidColumnReferenceError
# 'duplicate column name: `y_a`'
# position="114"

async def test_sql_query_49(self):
res = await self.scon.fetch(
'''
Expand All @@ -930,6 +935,11 @@ async def test_sql_query_49(self):
# duplicate rel var names can yield duplicate column names
self.assert_shape(res, 1, ['x_a', 'a', 'x_a'])

# Over edgedb-protocol, this query should raise:
# asyncpg.InvalidColumnReferenceError
# 'duplicate column name: `x_a`'
# position="83"

async def test_sql_query_50(self):
res = await self.scon.fetch(
'''
Expand Down Expand Up @@ -1332,6 +1342,7 @@ async def test_sql_query_static_eval_03(self):
SELECT information_schema._pg_truetypid(a.*, t.*)
FROM pg_attribute a
JOIN pg_type t ON t.oid = a.atttypid
LIMIT 500
'''
)

Expand Down

0 comments on commit 6ce3c9a

Please sign in to comment.