From 6e0cdaa616879e7bb50d4dcc107ddc270adf6bf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Mon, 11 Nov 2024 19:27:27 +0100 Subject: [PATCH] Fix an edge case of calling value functions from range vars (#7982) --- edb/pgsql/ast.py | 2 +- edb/pgsql/parser/ast_builder.py | 2 +- edb/pgsql/resolver/expr.py | 6 ++++-- edb/pgsql/resolver/range_var.py | 24 +++++++++++++----------- tests/test_sql_query.py | 11 +++++++++++ 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/edb/pgsql/ast.py b/edb/pgsql/ast.py index d97157b580d..cc87b61b442 100644 --- a/edb/pgsql/ast.py +++ b/edb/pgsql/ast.py @@ -949,7 +949,7 @@ class RangeFunction(BaseRangeVar): with_ordinality: bool = False # ROWS FROM form is_rowsfrom: bool = False - functions: typing.List[FuncCall] + functions: typing.List[BaseExpr] class JoinClause(BaseRangeVar): diff --git a/edb/pgsql/parser/ast_builder.py b/edb/pgsql/parser/ast_builder.py index ea4a23d3826..b7688634497 100644 --- a/edb/pgsql/parser/ast_builder.py +++ b/edb/pgsql/parser/ast_builder.py @@ -905,7 +905,7 @@ def _build_range_function(n: Node, c: Context) -> pgast.RangeFunction: with_ordinality=_bool_or_false(n, "ordinality"), is_rowsfrom=_bool_or_false(n, "is_rowsfrom"), functions=[ - _build_func_call(fn, c) + _build_base_expr(fn, c) for fn in n["functions"][0]["List"]["items"] if len(fn) > 0 ], diff --git a/edb/pgsql/resolver/expr.py b/edb/pgsql/resolver/expr.py index 69525b71a06..8edb6f96bd7 100644 --- a/edb/pgsql/resolver/expr.py +++ b/edb/pgsql/resolver/expr.py @@ -34,9 +34,10 @@ from edb.ir import ast as irast - from edb.edgeql import compiler as qlcompiler +from edb.server.pgcon import errors as pgerror + from . import dispatch from . import context from . import static @@ -260,7 +261,8 @@ def _lookup_column( if not matched_columns: raise errors.QueryError( - f'cannot find column `{col_name}`', span=column_ref.span + f'cannot find column `{col_name}`', span=column_ref.span, + pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE, ) # apply precedence diff --git a/edb/pgsql/resolver/range_var.py b/edb/pgsql/resolver/range_var.py index 3ac53c09bda..63b1eacd4a8 100644 --- a/edb/pgsql/resolver/range_var.py +++ b/edb/pgsql/resolver/range_var.py @@ -298,19 +298,21 @@ def _resolve_RangeFunction( ) -> Tuple[pgast.BaseRangeVar, context.Table]: with ctx.lateral() if range_var.lateral else ctx.child() as subctx: - functions = [] + functions: List[pgast.BaseExpr] = [] col_names = [] for function in range_var.functions: - - name = function.name[len(function.name) - 1] - if name in range_functions.COLUMNS: - col_names.extend(range_functions.COLUMNS[name]) - elif name == 'unnest': - col_names.extend('unnest' for _ in function.args) - else: - col_names.append(name) - - functions.append(dispatch.resolve(function, ctx=subctx)) + match function: + case pgast.FuncCall(): + name = function.name[len(function.name) - 1] + if name in range_functions.COLUMNS: + col_names.extend(range_functions.COLUMNS[name]) + elif name == 'unnest': + col_names.extend('unnest' for _ in function.args) + else: + col_names.append(name) + functions.append(dispatch.resolve(function, ctx=subctx)) + case _: + functions.append(dispatch.resolve(function, ctx=subctx)) inferred_columns = [ context.Column( diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index a556a90e31a..79a7326c9ec 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -860,6 +860,17 @@ async def test_sql_query_43(self): ) self.assertEqual(res, [[1, 1, 1], [2, None, 2], [None, 3, 3]]) + async def test_sql_query_44(self): + # range function that is an "sql value function", whatever that is + + # to be exact: User is *parsed* as function call CURRENT_USER + # we'd ideally want a message that hints that it should use quotes + + with self.assertRaisesRegex( + asyncpg.InvalidColumnReferenceError, 'cannot find column `name`' + ): + await self.squery_values('SELECT name FROM User') + async def test_sql_query_introspection_00(self): dbname = self.con.dbname res = await self.squery_values(