Skip to content

Commit

Permalink
Fix an edge case of calling value functions from range vars (#7982)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Nov 11, 2024
1 parent c386471 commit 6e0cdaa
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 15 deletions.
2 changes: 1 addition & 1 deletion edb/pgsql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ class RangeFunction(BaseRangeVar):
with_ordinality: bool = False
# ROWS FROM form
is_rowsfrom: bool = False
functions: typing.List[FuncCall]
functions: typing.List[BaseExpr]


class JoinClause(BaseRangeVar):
Expand Down
2 changes: 1 addition & 1 deletion edb/pgsql/parser/ast_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ def _build_range_function(n: Node, c: Context) -> pgast.RangeFunction:
with_ordinality=_bool_or_false(n, "ordinality"),
is_rowsfrom=_bool_or_false(n, "is_rowsfrom"),
functions=[
_build_func_call(fn, c)
_build_base_expr(fn, c)
for fn in n["functions"][0]["List"]["items"]
if len(fn) > 0
],
Expand Down
6 changes: 4 additions & 2 deletions edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@

from edb.ir import ast as irast


from edb.edgeql import compiler as qlcompiler

from edb.server.pgcon import errors as pgerror

from . import dispatch
from . import context
from . import static
Expand Down Expand Up @@ -260,7 +261,8 @@ def _lookup_column(

if not matched_columns:
raise errors.QueryError(
f'cannot find column `{col_name}`', span=column_ref.span
f'cannot find column `{col_name}`', span=column_ref.span,
pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE,
)

# apply precedence
Expand Down
24 changes: 13 additions & 11 deletions edb/pgsql/resolver/range_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,19 +298,21 @@ def _resolve_RangeFunction(
) -> Tuple[pgast.BaseRangeVar, context.Table]:
with ctx.lateral() if range_var.lateral else ctx.child() as subctx:

functions = []
functions: List[pgast.BaseExpr] = []
col_names = []
for function in range_var.functions:

name = function.name[len(function.name) - 1]
if name in range_functions.COLUMNS:
col_names.extend(range_functions.COLUMNS[name])
elif name == 'unnest':
col_names.extend('unnest' for _ in function.args)
else:
col_names.append(name)

functions.append(dispatch.resolve(function, ctx=subctx))
match function:
case pgast.FuncCall():
name = function.name[len(function.name) - 1]
if name in range_functions.COLUMNS:
col_names.extend(range_functions.COLUMNS[name])
elif name == 'unnest':
col_names.extend('unnest' for _ in function.args)
else:
col_names.append(name)
functions.append(dispatch.resolve(function, ctx=subctx))
case _:
functions.append(dispatch.resolve(function, ctx=subctx))

inferred_columns = [
context.Column(
Expand Down
11 changes: 11 additions & 0 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,17 @@ async def test_sql_query_43(self):
)
self.assertEqual(res, [[1, 1, 1], [2, None, 2], [None, 3, 3]])

async def test_sql_query_44(self):
# range function that is an "sql value function", whatever that is

# to be exact: User is *parsed* as function call CURRENT_USER
# we'd ideally want a message that hints that it should use quotes

with self.assertRaisesRegex(
asyncpg.InvalidColumnReferenceError, 'cannot find column `name`'
):
await self.squery_values('SELECT name FROM User')

async def test_sql_query_introspection_00(self):
dbname = self.con.dbname
res = await self.squery_values(
Expand Down

0 comments on commit 6e0cdaa

Please sign in to comment.