Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array index number to output when casting from json to array. #7397

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 124 additions & 28 deletions edb/edgeql/compiler/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def _cast_to_ir(
sql_function=cast.get_from_function(ctx.env.schema),
sql_cast=cast.get_from_cast(ctx.env.schema),
sql_expr=bool(cast.get_code(ctx.env.schema)),
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand All @@ -479,7 +479,7 @@ def _inheritance_cast_to_ir(
sql_function=None,
sql_cast=True,
sql_expr=False,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -663,12 +663,8 @@ def _cast_json_to_tuple(
source_path,
qlast.Constant.boolean(allow_null),
]
if error_message_context := cast_message_context(subctx):
json_object_args.append(qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
))
if error_message_context_ql := get_error_message_context_ql(subctx):
json_object_args.append(error_message_context_ql)
json_objects = qlast.FunctionCall(
func=('__std__', '__tuple_validate_json'),
args=json_object_args
Expand Down Expand Up @@ -698,12 +694,8 @@ def _cast_json_to_tuple(
subctx.collection_cast_info.path_elements.append(cast_element)

json_get_kwargs: dict[str, qlast.Expr] = {}
if error_message_context := cast_message_context(subctx):
json_get_kwargs['detail'] = qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
)
if error_message_context_ql := get_error_message_context_ql(subctx):
json_get_kwargs['detail'] = error_message_context_ql
val_e = qlast.FunctionCall(
func=('__std__', '__json_get_not_null'),
args=[
Expand Down Expand Up @@ -995,12 +987,8 @@ def _cast_json_to_range(
source_path = subctx.create_anchor(ir_set, 'a')

check_args: list[qlast.Expr] = [source_path]
if error_message_context := cast_message_context(subctx):
check_args.append(qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
))
if error_message_context_ql := get_error_message_context_ql(subctx):
check_args.append(error_message_context_ql)
check = qlast.FunctionCall(
func=('__std__', '__range_validate_json'),
args=check_args
Expand Down Expand Up @@ -1274,9 +1262,25 @@ def _cast_array(
if el_type.contains_json(subctx.env.schema):
subctx.inhibit_implicit_limit = True

if subctx.collection_cast_info is not None:
subctx.collection_cast_info.path_elements.append(
(
'array_index',
qlast.FunctionCall(
func=('__std__', 'to_str'),
args=[
astutils.extend_path(enumerated_ref, '0')
],
),
)
)

array_ir = dispatch.compile(correlated_query, ctx=subctx)
assert isinstance(array_ir, irast.Set)

if subctx.collection_cast_info is not None:
subctx.collection_cast_info.path_elements.pop()

if direct_cast is not None:
ctx.env.schema, array_stype = s_types.Array.from_subtypes(
ctx.env.schema, [el_type])
Expand Down Expand Up @@ -1342,7 +1346,7 @@ def _cast_array_literal(
sql_cast=True,
sql_expr=False,
span=span,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -1385,7 +1389,7 @@ def _cast_enum_str_immutable(
sql_function=None,
sql_cast=False,
sql_expr=True,
error_message_context=cast_message_context(ctx),
error_message_context=get_error_message_context_ir(ctx),
)

return setgen.ensure_set(cast_ir, ctx=ctx)
Expand Down Expand Up @@ -1450,7 +1454,7 @@ def _find_object_by_id(
return dispatch.compile(for_query, ctx=subctx)


def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
def get_error_message_context(ctx: context.ContextLevel) -> Optional[str]:
if (
ctx.collection_cast_info is not None
and ctx.collection_cast_info.path_elements
Expand All @@ -1462,8 +1466,13 @@ def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
ctx.collection_cast_info.to_type.get_displayname(ctx.env.schema)
)
path_msg = ''.join(
_collection_element_message_context(path_element)
for path_element in ctx.collection_cast_info.path_elements
element_message_context
for element_message_context in [
_element_message_context(path_element)
for path_element
in ctx.collection_cast_info.path_elements
]
if element_message_context is not None
)
return (
f"while casting '{from_name}' to '{to_name}', {path_msg}"
Expand All @@ -1472,14 +1481,101 @@ def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
return None


def _collection_element_message_context(
path_element: Tuple[str, Optional[str]]
) -> str:
def get_error_message_context_ql(
ctx: context.ContextLevel
) -> Optional[qlast.Expr]:
if (
ctx.collection_cast_info is not None
and ctx.collection_cast_info.path_elements
and any(
e[0] == 'array_index'
for e in ctx.collection_cast_info.path_elements
)
):
from_name = (
ctx.collection_cast_info.from_type.get_displayname(ctx.env.schema)
)
to_name = (
ctx.collection_cast_info.to_type.get_displayname(ctx.env.schema)
)

path_element_messages: list[qlast.Expr] = [
_element_message_context_ql(path_element)
for path_element in ctx.collection_cast_info.path_elements
]

return qlast.FunctionCall(
func=('__std__', 'array_join'),
args=[
qlast.Array(
elements=(
[qlast.Constant.string(
'{"error_message_context": "'
)]
+ [qlast.Constant.string(
f"while casting '{from_name}' to '{to_name}', "
)]
+ path_element_messages
+ [qlast.Constant.string('"}')]
),
),
qlast.Constant.string(""),
],
)

elif error_message_context := get_error_message_context(ctx):
return qlast.Constant.string(
json.dumps({
"error_message_context": error_message_context
})
)

return None


def get_error_message_context_ir(
ctx: context.ContextLevel
) -> Optional[irast.Set]:
if error_message_context_ql := get_error_message_context_ql(ctx):
return dispatch.compile(error_message_context_ql, ctx=ctx)

return None


def _element_message_context(
path_element: Tuple[str, Optional[str | qlast.Expr]]
) -> Optional[str]:
if path_element[0] == 'tuple':
assert isinstance(path_element[1], str)
return f"at tuple element '{path_element[1]}', "
elif path_element[0] == 'array':
return f'in array elements, '
elif path_element[0] == 'array_index':
return None
elif path_element[0] == 'range':
assert isinstance(path_element[1], str)
return f"in range parameter '{path_element[1]}', "
else:
raise NotImplementedError


def _element_message_context_ql(
path_element: Tuple[str, Optional[str | qlast.Expr]]
) -> qlast.Expr:
if path_element[0] == 'array_index':
assert isinstance(path_element[1], qlast.Expr)
return qlast.BinOp(
op='++',
left=qlast.Constant.string("at index "),
right=qlast.BinOp(
op='++',
left=path_element[1],
right=qlast.Constant.string(", "),
)
)

elif element_message_context := _element_message_context(path_element):
return qlast.Constant.string(element_message_context)

else:
raise NotImplementedError
6 changes: 5 additions & 1 deletion edb/edgeql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,16 @@ class CollectionCastInfo(NamedTuple):
from_type: s_types.Type
to_type: s_types.Type

path_elements: list[Tuple[str, Optional[str]]]
path_elements: list[Tuple[str, Optional[str | qlast.Expr]]]
"""Represents a path to the current collection element being cast.
A path element is a tuple of the collection type and an optional
element name. eg. ('tuple', 'a') or ('array', None)
When casting a json array, the path element has the form
('array_index', expr) where expr is a string representation of the index
in the current json array.
The list is shared between the outermost context and all its sub contexts.
When casting a collection, each element's path should be pushed before
entering the "sub-cast" and popped immediately after.
Expand Down
2 changes: 1 addition & 1 deletion edb/edgeql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def compile_TypeCast(

except errors.QueryError as e:
if (
(message_context := casts.cast_message_context(subctx))
(message_context := casts.get_error_message_context(subctx))
and use_message_context
):
e.args = (
Expand Down
5 changes: 5 additions & 0 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,11 @@ def __infer_typecast(
scope_tree: irast.ScopeTreeNode,
ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
if ir.error_message_context is not None:
infer_cardinality(
ir.error_message_context, scope_tree=scope_tree, ctx=ctx
)

card = infer_cardinality(
ir.expr, scope_tree=scope_tree, ctx=ctx,
)
Expand Down
5 changes: 5 additions & 0 deletions edb/edgeql/compiler/inference/multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,11 @@ def __infer_typecast(
scope_tree: irast.ScopeTreeNode,
ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
if ir.error_message_context is not None:
infer_multiplicity(
ir.error_message_context, scope_tree=scope_tree, ctx=ctx
)

return infer_multiplicity(
ir.expr, scope_tree=scope_tree, ctx=ctx,
)
Expand Down
2 changes: 1 addition & 1 deletion edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ class TypeCast(ImmutableExpr):
sql_function: typing.Optional[str] = None
sql_cast: bool
sql_expr: bool
error_message_context: typing.Optional[str] = None
error_message_context: typing.Optional[Set] = None

@property
def typeref(self) -> TypeRef:
Expand Down
10 changes: 2 additions & 8 deletions edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,9 @@ def compile_TypeCast(

pg_expr = dispatch.compile(expr.expr, ctx=ctx)

detail: Optional[pgast.StringConstant] = None
detail: Optional[pgast.BaseExpr] = None
if expr.error_message_context is not None:
detail = pgast.StringConstant(
val=(
'{"error_message_context": "'
+ expr.error_message_context
+ '"}'
)
)
detail = dispatch.compile(expr.error_message_context, ctx=ctx)

if expr.sql_cast:
# Use explicit SQL cast.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_edgeql_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON string"):
await self.con.query_single(
r"SELECT <array<int64>><json>['asdf']")
Expand All @@ -2469,6 +2470,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"expected JSON number or null; got JSON string"):
await self.con.query_single(
r"SELECT <array<int64>>to_json('[1, 2, \"asdf\"]')")
Expand All @@ -2478,6 +2480,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON string"):
await self.con.execute("""
SELECT <array<int64>>to_json('["a"]');
Expand All @@ -2497,6 +2500,7 @@ async def test_edgeql_casts_json_11(self):
edgedb.InvalidValueError,
r"array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"invalid null value in cast"):
await self.con.query_single(
r"SELECT <array<int64>>to_json('[1, 2, null]')")
Expand All @@ -2506,6 +2510,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'array<std::json>' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 2, "
r"invalid null value in cast"):
await self.con.query_single(
r"SELECT <array<int64>><array<json>>to_json('[1, 2, null]')")
Expand All @@ -2525,6 +2530,7 @@ async def test_edgeql_casts_json_11(self):
r"to 'tuple<array<std::str>>', "
r"at tuple element '0', "
r"in array elements, "
r"at index 0, "
r"invalid null value in cast"):
await self.con.query_single(
r"select <tuple<array<str>>>to_json('[[null]]')")
Expand All @@ -2543,6 +2549,7 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<std::int64>', "
r"in array elements, "
r"at index 0, "
r"expected JSON number or null; got JSON object"):
await self.con.execute("""
SELECT <array<int64>>to_json('[{"a": 1}]');
Expand All @@ -2554,8 +2561,10 @@ async def test_edgeql_casts_json_11(self):
r"while casting 'std::json' "
r"to 'array<tuple<array<std::str>>>', "
r"in array elements, "
r"at index 0, "
r"at tuple element '0', "
r"in array elements, "
r"at index 0, "
r"expected JSON string or null; got JSON number"):
await self.con.execute("""
SELECT <array<tuple<array<str>>>>to_json('[[[1]]]');
Expand Down