Skip to content

Commit

Permalink
Pack variadic arguments into an array.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Sep 10, 2024
1 parent 664a49e commit 3be4bf1
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 13 deletions.
8 changes: 6 additions & 2 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def compile_FunctionCall(
rtype, env=env,
),
typemod=matched_call.func.get_return_typemod(env.schema),
has_empty_variadic=matched_call.has_empty_variadic,
variadic_arg_id=matched_call.variadic_arg_id,
variadic_arg_count=matched_call.variadic_arg_count,
variadic_param_type=variadic_param_type,
func_initial_value=func_initial_value,
tuple_path_ids=tuple_path_ids,
Expand Down Expand Up @@ -1085,7 +1086,10 @@ def finalize_args(
param_name_to_arg[param_shortname] = param_shortname
else:
args[position_index] = arg
if not barg.is_default:
if (
not barg.is_default
and param_shortname not in param_name_to_arg
):
param_name_to_arg[param_shortname] = position_index
position_index += 1

Expand Down
24 changes: 18 additions & 6 deletions edb/edgeql/compiler/polyres.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class BoundCall(NamedTuple):
args: List[BoundArg]
null_args: Set[str]
return_type: s_types.Type
has_empty_variadic: bool
variadic_arg_id: Optional[int]
variadic_arg_count: Optional[int]


_VARIADIC = ft.ParameterKind.VariadicParam
Expand Down Expand Up @@ -309,7 +310,8 @@ def _get_cast_distance(
ctx.env.options.func_params.has_polymorphic(schema)
)

has_empty_variadic = False
variadic_arg_id: Optional[int] = None
variadic_arg_count: Optional[int] = None
no_args_call = not args and not kwargs
has_inlined_defaults = func.has_inlined_defaults(schema)

Expand All @@ -330,8 +332,8 @@ def _get_cast_distance(
ctx=ctx)
bargs = [BoundArg(None, bytes_t, argval, bytes_t, 0, -1)]
return BoundCall(
func, bargs, set(),
return_type, False)
func, bargs, set(), return_type, None, None
)
else:
# No match: `func` is a function without parameters
# being called with some arguments.
Expand Down Expand Up @@ -430,6 +432,9 @@ def _get_cast_distance(
BoundArg(param, param_type, arg_val, arg_type, cd,
ai + di))

variadic_arg_id = ai - 1
variadic_arg_count = nargs - ai + 1

break

cd = _get_cast_distance(arg_val, arg_type, param_type)
Expand Down Expand Up @@ -458,7 +463,8 @@ def _get_cast_distance(
bound_args_prep.append(MissingArg(param, param_type))

elif param_kind is _VARIADIC:
has_empty_variadic = True
variadic_arg_id = i
variadic_arg_count = 0

elif param_kind is _NAMED_ONLY:
# impossible condition
Expand Down Expand Up @@ -592,7 +598,13 @@ def _get_cast_distance(
)

return BoundCall(
func, bound_param_args, null_args, return_type, has_empty_variadic)
func,
bound_param_args,
null_args,
return_type,
variadic_arg_id,
variadic_arg_count,
)


def compile_arg(
Expand Down
10 changes: 7 additions & 3 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,9 +1013,13 @@ class FunctionCall(Call):
# handle empty set
func_initial_value: typing.Optional[Set] = None

# True if the bound function has a variadic parameter and
# there are no arguments that are bound to it.
has_empty_variadic: bool = False
# If the bound function has a variadic parameter, this will be the index
# of the variadic argument within the argument list.
variadic_arg_id: typing.Optional[int] = None

# If the bound function has a variadic parameter, this will be the number
# of variadic arguments
variadic_arg_count: typing.Optional[int] = None

# The underlying SQL function has OUT parameters.
sql_func_has_out_params: bool = False
Expand Down
2 changes: 1 addition & 1 deletion edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def compile_FunctionCall(

args, maybe_null = _compile_call_args(expr, ctx=ctx)

if expr.has_empty_variadic and expr.variadic_param_type is not None:
if expr.variadic_arg_count == 0 and expr.variadic_param_type is not None:
var = pgast.TypeCast(
arg=pgast.ArrayExpr(elements=[]),
type_name=pgast.TypeName(
Expand Down
20 changes: 19 additions & 1 deletion edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3423,7 +3423,7 @@ def _compile_call_args(

if (
isinstance(expr, irast.FunctionCall)
and expr.has_empty_variadic
and expr.variadic_arg_count == 0
and expr.variadic_param_type is not None
):
var = pgast.TypeCast(
Expand Down Expand Up @@ -3483,7 +3483,25 @@ def process_set_as_func_expr(
newctx.inlined_args = {
arg_key: args[index]
for arg_key, index in arg_indexes.items()
if (
isinstance(arg_key, str)
or expr.variadic_arg_id is None
or arg_key < expr.variadic_arg_id
)
}
if (
expr.variadic_arg_id is not None
and expr.variadic_arg_count is not None
):
newctx.inlined_args[expr.variadic_arg_id] = pgast.ArrayExpr(
elements=[
args[arg_indexes[index]]
for index in range(
expr.variadic_arg_id,
expr.variadic_arg_id + expr.variadic_arg_count
)
]
)
set_expr = dispatch.compile(expr.body, ctx=newctx)

else:
Expand Down

0 comments on commit 3be4bf1

Please sign in to comment.