Skip to content

Commit

Permalink
Improve error when applying a shape to a parameter (#7044)
Browse files Browse the repository at this point in the history
* Add Params data class.

* Add error message for shaped parameters.

* Add tests for new error.

* Add hint and change error span to cover shape as well.
  • Loading branch information
dnwpark authored Mar 15, 2024
1 parent f8676b2 commit d7515df
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 13 deletions.
30 changes: 21 additions & 9 deletions edb/edgeql/compiler/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Tuple, Dict, List

from edb.common import ast
Expand Down Expand Up @@ -121,16 +122,24 @@ def extend_path(expr: qlast.Expr, field: str) -> qlast.Path:
return qlast.Path(steps=[expr, step])


@dataclass
class Params:
cast_params: List[
Tuple[qlast.TypeCast, Dict[Optional[str], str]]
] = field(default_factory=list)
shaped_params: List[
Tuple[qlast.Parameter, qlast.Shape]
] = field(default_factory=list)
loose_params: List[qlast.Parameter] = field(default_factory=list)

class FindParams(ast.NodeVisitor):
"""Visitor to find all the parameters.
The annoying bit is that we also need all the modaliases.
"""
def __init__(self, modaliases: Dict[Optional[str], str]) -> None:
super().__init__()
self.params: List[
Tuple[qlast.TypeCast, Dict[Optional[str], str]]] = []
self.loose_params: List[qlast.Parameter] = []
self.params: Params = Params()
self.modaliases = modaliases

def visit_Command(self, n: qlast.Command) -> None:
Expand All @@ -154,12 +163,17 @@ def _visit_with_stmt(self, n: qlast.Statement) -> None:

def visit_TypeCast(self, n: qlast.TypeCast) -> None:
if isinstance(n.expr, qlast.Parameter):
self.params.append((n, self.modaliases))
self.params.cast_params.append((n, self.modaliases))
elif isinstance(n.expr, qlast.Shape):
if isinstance(n.expr.expr, qlast.Parameter):
self.params.shaped_params.append((n.expr.expr, n.expr))
else:
self.generic_visit(n)
else:
self.generic_visit(n)

def visit_Parameter(self, n: qlast.Parameter) -> None:
self.loose_params.append(n)
self.params.loose_params.append(n)

def visit_CreateFunction(self, n: qlast.CreateFunction) -> None:
pass
Expand All @@ -170,13 +184,11 @@ def visit_CreateConstraint(self, n: qlast.CreateFunction) -> None:

def find_parameters(
ql: qlast.Base, modaliases: Dict[Optional[str], str]
) -> Tuple[
List[Tuple[qlast.TypeCast, Dict[Optional[str], str]]],
List[qlast.Parameter]]:
) -> Params:
"""Get all query parameters"""
v = FindParams(modaliases)
v.visit(ql)
return v.params, v.loose_params
return v.params


class alias_view(
Expand Down
25 changes: 21 additions & 4 deletions edb/edgeql/compiler/stmtctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,17 @@ def check_params(params: Dict[str, irast.Param]) -> None:
f'{"s" if len(missing_args) > 1 else ""}')


def throw_on_shaped_param(
param: qlast.Parameter,
shape: qlast.Shape,
ctx: context.ContextLevel
) -> None:
raise errors.QueryError(
f'cannot apply a shape to the parameter',
hint='Consider adding parentheses around the parameter and type cast',
context=shape.context)


def throw_on_loose_param(
param: qlast.Parameter,
ctx: context.ContextLevel
Expand Down Expand Up @@ -846,19 +857,25 @@ def preprocess_script(
Doing this in advance makes it easy to check that they have
consistent types.
"""
param_lists = [
params_lists = [
astutils.find_parameters(stmt, ctx.modaliases)
for stmt in stmts
]

if loose_params := [
loose for _, loose_list in param_lists
for loose in loose_list
loose for params in params_lists
for loose in params.loose_params
]:
throw_on_loose_param(loose_params[0], ctx)

if shaped_params := [
shaped for params in params_lists
for shaped in params.shaped_params
]:
throw_on_shaped_param(shaped_params[0][0], shaped_params[0][1], ctx)

casts = [
cast for cast_lists, _ in param_lists for cast in cast_lists
cast for params in params_lists for cast in params.cast_params
]
params = {}
for cast, modaliases in casts:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_edgeql_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -8212,6 +8212,12 @@ async def test_edgeql_select_params_03(self):
with self.assertRaisesRegex(edgedb.QueryError, "missing a type cast"):
await self.con.query("select ($0, <std::int64>$0)")

async def test_edgeql_select_params_04(self):
with self.assertRaisesRegex(edgedb.QueryError,
"cannot apply a shape to the parameter"):
await self.con.query("select <std::int64>$0 { id }")


async def test_edgeql_type_pointer_inlining_01(self):
await self.con._fetchall(
r'''
Expand Down
7 changes: 7 additions & 0 deletions tests/test_server_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,13 @@ async def test_server_proto_args_07(self):
await self.con.query_single(
'select schema::Object {name} filter .id=$id', id='asd')

async def test_server_proto_args_07_1(self):
with self.assertRaisesRegex(edgedb.QueryError,
"cannot apply a shape to the parameter"):
await self.con.query_single(
'select schema::Object filter .id=<uuid>$id {name}', id='asd')


async def test_server_proto_args_08(self):
async with self._run_and_rollback():
await self.con.execute(
Expand Down

0 comments on commit d7515df

Please sign in to comment.