Skip to content

Commit

Permalink
Add array index number to output when casting from json to array.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed May 24, 2024
1 parent d8ae9ec commit 542ebb1
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 40 deletions.
152 changes: 124 additions & 28 deletions edb/edgeql/compiler/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def _cast_to_ir(
sql_function=cast.get_from_function(ctx.env.schema),
sql_cast=cast.get_from_cast(ctx.env.schema),
sql_expr=bool(cast.get_code(ctx.env.schema)),
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand All @@ -479,7 +479,7 @@ def _inheritance_cast_to_ir(
sql_function=None,
sql_cast=True,
sql_expr=False,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -663,12 +663,8 @@ def _cast_json_to_tuple(
source_path,
qlast.Constant.boolean(allow_null),
]
if error_message_context := cast_message_context(subctx):
json_object_args.append(qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
))
if error_message_context_ql := get_error_message_context_ql(subctx):
json_object_args.append(error_message_context_ql)
json_objects = qlast.FunctionCall(
func=('__std__', '__tuple_validate_json'),
args=json_object_args
Expand Down Expand Up @@ -698,12 +694,8 @@ def _cast_json_to_tuple(
subctx.collection_cast_info.path_elements.append(cast_element)

json_get_kwargs: dict[str, qlast.Expr] = {}
if error_message_context := cast_message_context(subctx):
json_get_kwargs['detail'] = qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
)
if error_message_context_ql := get_error_message_context_ql(subctx):
json_get_kwargs['detail'] = error_message_context_ql
val_e = qlast.FunctionCall(
func=('__std__', '__json_get_not_null'),
args=[
Expand Down Expand Up @@ -995,12 +987,8 @@ def _cast_json_to_range(
source_path = subctx.create_anchor(ir_set, 'a')

check_args: list[qlast.Expr] = [source_path]
if error_message_context := cast_message_context(subctx):
check_args.append(qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
))
if error_message_context_ql := get_error_message_context_ql(subctx):
check_args.append(error_message_context_ql)
check = qlast.FunctionCall(
func=('__std__', '__range_validate_json'),
args=check_args
Expand Down Expand Up @@ -1274,9 +1262,25 @@ def _cast_array(
if el_type.contains_json(subctx.env.schema):
subctx.inhibit_implicit_limit = True

if subctx.collection_cast_info is not None:
subctx.collection_cast_info.path_elements.append(
(
'array_index',
qlast.FunctionCall(
func=('__std__', 'to_str'),
args=[
astutils.extend_path(enumerated_ref, '0')
],
),
)
)

array_ir = dispatch.compile(correlated_query, ctx=subctx)
assert isinstance(array_ir, irast.Set)

if subctx.collection_cast_info is not None:
subctx.collection_cast_info.path_elements.pop()

if direct_cast is not None:
ctx.env.schema, array_stype = s_types.Array.from_subtypes(
ctx.env.schema, [el_type])
Expand Down Expand Up @@ -1342,7 +1346,7 @@ def _cast_array_literal(
sql_cast=True,
sql_expr=False,
span=span,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -1385,7 +1389,7 @@ def _cast_enum_str_immutable(
sql_function=None,
sql_cast=False,
sql_expr=True,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -1450,7 +1454,7 @@ def _find_object_by_id(
return dispatch.compile(for_query, ctx=subctx)


def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
def get_error_message_context(ctx: context.ContextLevel) -> Optional[str]:
if (
ctx.collection_cast_info is not None
and ctx.collection_cast_info.path_elements
Expand All @@ -1462,8 +1466,13 @@ def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
ctx.collection_cast_info.to_type.get_displayname(ctx.env.schema)
)
path_msg = ''.join(
_collection_element_message_context(path_element)
for path_element in ctx.collection_cast_info.path_elements
element_message_context
for element_message_context in [
_element_message_context(path_element)
for path_element
in ctx.collection_cast_info.path_elements
]
if element_message_context is not None
)
return (
f"while casting '{from_name}' to '{to_name}', {path_msg}"
Expand All @@ -1472,14 +1481,101 @@ def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
return None


def _collection_element_message_context(
path_element: Tuple[str, Optional[str]]
) -> str:
def get_error_message_context_ql(
ctx: context.ContextLevel
) -> Optional[qlast.Expr]:
if (
ctx.collection_cast_info is not None
and ctx.collection_cast_info.path_elements
and any(
e[0] == 'array_index'
for e in ctx.collection_cast_info.path_elements
)
):
from_name = (
ctx.collection_cast_info.from_type.get_displayname(ctx.env.schema)
)
to_name = (
ctx.collection_cast_info.to_type.get_displayname(ctx.env.schema)
)

path_element_messages: list[qlast.Expr] = [
_element_message_context_ql(path_element)
for path_element in ctx.collection_cast_info.path_elements
]

return qlast.FunctionCall(
func=('__std__', 'array_join'),
args=[
qlast.Array(
elements=(
[qlast.Constant.string(
'{"error_message_context": "'
)]
+ [qlast.Constant.string(
f"while casting '{from_name}' to '{to_name}', "
)]
+ path_element_messages
+ [qlast.Constant.string('"}')]
),
),
qlast.Constant.string(""),
],
)

elif error_message_context := get_error_message_context(ctx):
return qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
)

return None


def get_error_message_context_ir(
ctx: context.ContextLevel
) -> Optional[irast.Set]:
if error_message_context_ql := get_error_message_context_ql(ctx):
return dispatch.compile(error_message_context_ql, ctx=ctx)

return None


def _element_message_context(
path_element: Tuple[str, Optional[str | qlast.Expr]]
) -> Optional[str]:
if path_element[0] == 'tuple':
assert isinstance(path_element[1], str)
return f"at tuple element '{path_element[1]}', "
elif path_element[0] == 'array':
return f'in array elements, '
elif path_element[0] == 'array_index':
return None
elif path_element[0] == 'range':
assert isinstance(path_element[1], str)
return f"in range parameter '{path_element[1]}', "
else:
raise NotImplementedError


def _element_message_context_ql(
path_element: Tuple[str, Optional[str | qlast.Expr]]
) -> qlast.Expr:
if path_element[0] == 'array_index':
assert isinstance(path_element[1], qlast.Expr)
return qlast.BinOp(
op='++',
left=qlast.Constant.string("at index "),
right=qlast.BinOp(
op='++',
left=path_element[1],
right=qlast.Constant.string(", "),
)
)

elif element_message_context := _element_message_context(path_element):
return qlast.Constant.string(element_message_context)

else:
raise NotImplementedError
6 changes: 5 additions & 1 deletion edb/edgeql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,16 @@ class CollectionCastInfo(NamedTuple):
from_type: s_types.Type
to_type: s_types.Type

path_elements: list[Tuple[str, Optional[str]]]
path_elements: list[Tuple[str, Optional[str | qlast.Expr]]]
"""Represents a path to the current collection element being cast.
A path element is a tuple of the collection type and an optional
element name. eg. ('tuple', 'a') or ('array', None)
When casting a json array, the path element has the form
('array_index', expr) where expr is a string representation of the index
in the current json array.
The list is shared between the outermost context and all its sub contexts.
When casting a collection, each element's path should be pushed before
entering the "sub-cast" and popped immediately after.
Expand Down
2 changes: 1 addition & 1 deletion edb/edgeql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def compile_TypeCast(

except errors.QueryError as e:
if (
(message_context := casts.cast_message_context(subctx))
(message_context := casts.get_error_message_context(subctx))
and use_message_context
):
e.args = (
Expand Down
5 changes: 4 additions & 1 deletion edb/edgeql/compiler/stmtctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

from edb.edgeql import ast as qlast

from edb.common.ast import visitor as ast_visitor
from edb.common.ast import visitor as ast_visitor, find_children
from edb.common import ordered
from edb.common.typeutils import not_none

Expand Down Expand Up @@ -191,6 +191,9 @@ def fini_expression(
if p.sub_params and p.sub_params.decoder_ir
]
extra_exprs += [trigger.expr for stage in ir_triggers for trigger in stage]
for type_cast in find_children(ir, irast.TypeCast):
if type_cast.error_message_context is not None:
extra_exprs.append(type_cast.error_message_context)

all_exprs = [ir] + extra_exprs

Expand Down
2 changes: 1 addition & 1 deletion edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ class TypeCast(ImmutableExpr):
sql_function: typing.Optional[str] = None
sql_cast: bool
sql_expr: bool
error_message_context: typing.Optional[str] = None
error_message_context: typing.Optional[Set] = None

@property
def typeref(self) -> TypeRef:
Expand Down
10 changes: 2 additions & 8 deletions edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,9 @@ def compile_TypeCast(

pg_expr = dispatch.compile(expr.expr, ctx=ctx)

detail: Optional[pgast.StringConstant] = None
detail: Optional[pgast.BaseExpr] = None
if expr.error_message_context is not None:
detail = pgast.StringConstant(
val=(
'{"error_message_context": "'
+ expr.error_message_context
+ '"}'
)
)
detail = dispatch.compile(expr.error_message_context, ctx=ctx)

if expr.sql_cast:
# Use explicit SQL cast.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_edgeql_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON string"):
await self.con.query_single(
r"SELECT <array<int64>><json>['asdf']")
Expand All @@ -2469,6 +2470,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"expected JSON number or null; got JSON string"):
await self.con.query_single(
r"SELECT <array<int64>>to_json('[1, 2, \"asdf\"]')")
Expand All @@ -2478,6 +2480,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON string"):
await self.con.execute("""
SELECT <array<int64>>to_json('["a"]');
Expand All @@ -2497,6 +2500,7 @@ async def test_edgeql_casts_json_11(self):
edgedb.InvalidValueError,
r"array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"invalid null value in cast"):
await self.con.query_single(
r"SELECT <array<int64>>to_json('[1, 2, null]')")
Expand All @@ -2506,6 +2510,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'array<std::json>' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"invalid null value in cast"):
await self.con.query_single(
r"SELECT <array<int64>><array<json>>to_json('[1, 2, null]')")
Expand All @@ -2525,6 +2530,7 @@ async def test_edgeql_casts_json_11(self):
r"to 'tuple<array<std::str>>', "
r"at tuple element '0', "
r"in array elements, "
r"at index 0, "
r"invalid null value in cast"):
await self.con.query_single(
r"select <tuple<array<str>>>to_json('[[null]]')")
Expand All @@ -2543,6 +2549,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON object"):
await self.con.execute("""
SELECT <array<int64>>to_json('[{"a": 1}]');
Expand All @@ -2554,8 +2561,10 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<tuple<array<std::str>>>', "
r"in array elements, "
r"at index 0, "
r"at tuple element '0', "
r"in array elements, "
r"at index 0, "
r"expected JSON string or null; got JSON number"):
await self.con.execute("""
SELECT <array<tuple<array<str>>>>to_json('[[[1]]]');
Expand Down

0 comments on commit 542ebb1

Please sign in to comment.