-
Notifications
You must be signed in to change notification settings - Fork 407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement function inlining. #7713
Changes from 16 commits
9fd98ba
84215b8
9b766c6
719e983
66cfc6f
211915d
3ef7941
a1b08a0
17237a8
9592273
f5672a9
2ce748a
686f45c
fe48d68
45dbd62
a82fb74
a35e431
d19da8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,14 +37,17 @@ | |
) | ||
|
||
from edb import errors | ||
from edb.common import ast | ||
from edb.common import parsing | ||
from edb.common.typeutils import not_none | ||
|
||
from edb.ir import ast as irast | ||
from edb.ir import staeval | ||
from edb.ir import utils as irutils | ||
from edb.ir import typeutils as irtyputils | ||
|
||
from edb.schema import constraints as s_constr | ||
from edb.schema import delta as sd | ||
from edb.schema import functions as s_func | ||
from edb.schema import modules as s_mod | ||
from edb.schema import name as sn | ||
|
@@ -196,6 +199,26 @@ def compile_FunctionCall( | |
func = matched_call.func | ||
assert isinstance(func, s_func.Function) | ||
|
||
inline_func = None | ||
if ( | ||
func.get_language(ctx.env.schema) == qlast.Language.EdgeQL | ||
and func.get_is_inlined(ctx.env.schema) | ||
): | ||
inline_func = s_func.compile_function_inline( | ||
schema=ctx.env.schema, | ||
context=sd.CommandContext( | ||
schema=ctx.env.schema, | ||
), | ||
body=not_none(func.get_nativecode(ctx.env.schema)), | ||
func_name=func.get_name(ctx.env.schema), | ||
params=func.get_params(ctx.env.schema), | ||
language=not_none(func.get_language(ctx.env.schema)), | ||
return_type=func.get_return_type(ctx.env.schema), | ||
return_typemod=func.get_return_typemod(ctx.env.schema), | ||
track_schema_ref_exprs=False, | ||
inlining_context=ctx, | ||
) | ||
|
||
# Record this node in the list of potential DML expressions. | ||
if func.get_has_dml(env.schema): | ||
ctx.env.dml_exprs.append(expr) | ||
|
@@ -246,7 +269,7 @@ def compile_FunctionCall( | |
|
||
matched_func_initial_value = func.get_initial_value(env.schema) | ||
|
||
final_args = finalize_args( | ||
final_args, param_name_to_arg_key = finalize_args( | ||
matched_call, | ||
guessed_typemods=typemods, | ||
is_polymorphic=is_polymorphic, | ||
|
@@ -331,11 +354,15 @@ def compile_FunctionCall( | |
rtype, env=env, | ||
), | ||
typemod=matched_call.func.get_return_typemod(env.schema), | ||
has_empty_variadic=matched_call.has_empty_variadic, | ||
has_empty_variadic=(matched_call.variadic_arg_count == 0), | ||
variadic_param_type=variadic_param_type, | ||
func_initial_value=func_initial_value, | ||
tuple_path_ids=tuple_path_ids, | ||
impl_is_strict=func.get_impl_is_strict(env.schema), | ||
impl_is_strict=( | ||
func.get_impl_is_strict(env.schema) | ||
# Inlined functions should always check for null arguments. | ||
and not inline_func | ||
), | ||
prefer_subquery_args=func.get_prefer_subquery_args(env.schema), | ||
is_singleton_set_of=func.get_is_singleton_set_of(env.schema), | ||
global_args=global_args, | ||
|
@@ -345,6 +372,88 @@ def compile_FunctionCall( | |
# Apply special function handling | ||
if special_func := _SPECIAL_FUNCTIONS.get(str(func_name)): | ||
res = special_func(fcall, ctx=ctx) | ||
elif inline_func: | ||
res = fcall | ||
|
||
# TODO: Global parameters still use the implicit globals parameter. | ||
# They should be directly substituted in whenever possible. | ||
|
||
inline_args: dict[str, irast.CallArg | irast.Set] = {} | ||
|
||
# Collect non-default call args to inline | ||
for param_shortname, arg_key in param_name_to_arg_key.items(): | ||
if ( | ||
isinstance(arg_key, int) | ||
and matched_call.variadic_arg_id is not None | ||
and arg_key >= matched_call.variadic_arg_id | ||
): | ||
continue | ||
|
||
arg = final_args[arg_key] | ||
if arg.is_default: | ||
continue | ||
|
||
inline_args[param_shortname] = arg | ||
|
||
# Package variadic arguments into an array | ||
if variadic_param is not None: | ||
assert variadic_param_type is not None | ||
assert matched_call.variadic_arg_id is not None | ||
assert matched_call.variadic_arg_count is not None | ||
|
||
param_shortname = variadic_param.get_parameter_name(env.schema) | ||
inline_args[param_shortname] = ir_set = setgen.ensure_set( | ||
irast.Array( | ||
elements=[ | ||
final_args[arg_key].expr | ||
for arg_key in range( | ||
matched_call.variadic_arg_id, | ||
matched_call.variadic_arg_id | ||
+ matched_call.variadic_arg_count | ||
) | ||
], | ||
typeref=variadic_param_type, | ||
), | ||
ctx=ctx, | ||
) | ||
|
||
# Compile default args if necessary | ||
for param in matched_func_params.objects(env.schema): | ||
param_shortname = param.get_parameter_name(env.schema) | ||
|
||
if param_shortname in inline_args: | ||
continue | ||
|
||
else: | ||
# Missing named only args have their default values already | ||
# compiled in try_bind_call_args. | ||
if bound_args := [ | ||
bound_arg | ||
for bound_arg in matched_call.args | ||
if bound_arg.param == param and bound_arg.is_default | ||
]: | ||
assert len(bound_args) == 1 | ||
inline_args[param_shortname] = bound_args[0].val | ||
continue | ||
|
||
# Check if default is available | ||
p_default = param.get_default(env.schema) | ||
if p_default is None: | ||
continue | ||
|
||
# Compile default | ||
assert isinstance(param, s_func.Parameter) | ||
p_ir_default = param.get_ir_default( | ||
schema=env.schema, | ||
inlining_context=ctx, | ||
) | ||
inline_args[param_shortname] = ( | ||
p_ir_default.expr | ||
) | ||
|
||
argument_inliner = ArgumentInliner(inline_args, ctx=ctx) | ||
res.body = argument_inliner.visit(inline_func) | ||
|
||
else: | ||
res = fcall | ||
|
||
|
@@ -356,14 +465,80 @@ def compile_FunctionCall( | |
for arg in res.args.values(): | ||
pathctx.register_set_in_scope( | ||
arg.expr, | ||
optional=arg.param_typemod == ft.TypeModifier.OptionalType, | ||
optional=( | ||
arg.param_typemod == ft.TypeModifier.OptionalType | ||
), | ||
ctx=ctx, | ||
) | ||
|
||
ir_set = setgen.ensure_set(res, typehint=rtype, path_id=path_id, ctx=ctx) | ||
return stmt.maybe_add_view(ir_set, ctx=ctx) | ||
|
||
|
||
class ArgumentInliner(ast.NodeTransformer): | ||
|
||
mapped_args: dict[irast.PathId, irast.PathId] | ||
Comment on lines
+473
to
+475
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realize now that it should be possible (and probably nicer) to put the inlined arguments in the context somehow while compiling the inlined function, and inserting them directly when compiling Parameters, rather then substituting them in after the fact. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I decided to not go down that path since the backend compiler does stuff with args too. Probably possible though. |
||
inlined_arg_keys: list[int | str] | ||
|
||
def __init__( | ||
self, | ||
inline_args: dict[str, irast.CallArg | irast.Set], | ||
ctx: context.ContextLevel, | ||
) -> None: | ||
super().__init__() | ||
self.inline_args = inline_args | ||
self.ctx = ctx | ||
self.mapped_args = {} | ||
|
||
def visit_Set(self, node: irast.Set) -> irast.Base: | ||
if ( | ||
isinstance(node.expr, irast.Parameter) | ||
and node.expr.name in self.inline_args | ||
): | ||
arg = self.inline_args[node.expr.name] | ||
if isinstance(arg, irast.CallArg): | ||
# Inline param as an expr ref. The pg compiler will find the | ||
# appropriate rvar. | ||
self.mapped_args[node.path_id] = arg.expr.path_id | ||
return setgen.ensure_set( | ||
irast.InlinedParameterExpr( | ||
typeref=arg.expr.typeref, | ||
required=node.expr.required, | ||
is_global=node.expr.is_global, | ||
), | ||
path_id=arg.expr.path_id, | ||
ctx=self.ctx, | ||
) | ||
else: | ||
# Directly inline the set. | ||
# Used for default values, which are constants. | ||
return arg | ||
|
||
elif isinstance(node.expr, irast.Pointer): | ||
# The set and source path ids must match in order for the pointer | ||
# to find the correct rvar. If a pointer's source path was modified | ||
# because of an inlined parameter, modify the pointer's path as | ||
# well. | ||
prev_source_path_id = node.expr.source.path_id | ||
result = cast(irast.Set, self.generic_visit(node)) | ||
|
||
if prev_source_path_id in self.mapped_args: | ||
result = setgen.new_set_from_set( | ||
result, | ||
path_id=irtyputils.replace_pathid_prefix( | ||
result.path_id, | ||
prev_source_path_id, | ||
self.mapped_args[prev_source_path_id], | ||
), | ||
ctx=self.ctx, | ||
) | ||
self.mapped_args[node.path_id] = result.path_id | ||
|
||
return result | ||
|
||
return cast(irast.Base, self.generic_visit(node)) | ||
|
||
|
||
class _SpecialCaseFunc(Protocol): | ||
def __call__( | ||
self, call: irast.FunctionCall, *, ctx: context.ContextLevel | ||
|
@@ -618,7 +793,7 @@ def compile_operator( | |
matched_rtype.is_polymorphic(env.schema) | ||
) | ||
|
||
final_args = finalize_args( | ||
final_args, _ = finalize_args( | ||
matched_call, | ||
actual_typemods=actual_typemods, | ||
guessed_typemods=typemods, | ||
|
@@ -856,9 +1031,10 @@ def finalize_args( | |
guessed_typemods: Dict[Union[int, str], ft.TypeModifier], | ||
is_polymorphic: bool = False, | ||
ctx: context.ContextLevel, | ||
) -> Dict[Union[int, str], irast.CallArg]: | ||
) -> tuple[dict[int | str, irast.CallArg], dict[str, int | str]]: | ||
|
||
args: Dict[Union[int, str], irast.CallArg] = {} | ||
args: dict[int | str, irast.CallArg] = {} | ||
param_name_to_arg: dict[str, int | str] = {} | ||
position_index: int = 0 | ||
|
||
for i, barg in enumerate(bound_call.args): | ||
|
@@ -867,6 +1043,7 @@ def finalize_args( | |
arg_type_path_id: Optional[irast.PathId] = None | ||
if param is None: | ||
# defaults bitmask | ||
param_name_to_arg['__defaults_mask__'] = -1 | ||
args[-1] = irast.CallArg( | ||
expr=arg_val, | ||
param_typemod=ft.TypeModifier.SingletonType, | ||
|
@@ -972,14 +1149,21 @@ def finalize_args( | |
|
||
arg = irast.CallArg(expr=arg_val, expr_type_path_id=arg_type_path_id, | ||
is_default=barg.is_default, param_typemod=param_mod) | ||
param_shortname = param.get_parameter_name(ctx.env.schema) | ||
if param_kind is ft.ParameterKind.NamedOnlyParam: | ||
param_shortname = param.get_parameter_name(ctx.env.schema) | ||
args[param_shortname] = arg | ||
param_name_to_arg[param_shortname] = param_shortname | ||
else: | ||
args[position_index] = arg | ||
if ( | ||
# Variadic args will all have the same name, but different | ||
# indexes. We want to take the first index. | ||
param_shortname not in param_name_to_arg | ||
): | ||
param_name_to_arg[param_shortname] = position_index | ||
position_index += 1 | ||
|
||
return args | ||
return args, param_name_to_arg | ||
|
||
|
||
@_special_case('ext::ai::search') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't use these entrypoints anymore, right? Can we drop the
inlining_context
from them?