Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement function inlining. #7713

Merged
merged 18 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions edb/edgeql/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
from edb.ir import ast as irast
from edb.ir import staeval as ireval

from . import context
from . import dispatch as dispatch_mod
from . import inference as inference_mod
from . import normalization as norm_mod
Expand Down Expand Up @@ -201,6 +202,7 @@ def compile_ast_to_ir(
*,
script_info: Optional[irast.ScriptInfo] = None,
options: Optional[CompilerOptions] = None,
inlining_context: Optional[context.ContextLevel] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't use these entrypoints anymore, right? Can we drop the inlining_context from them?

) -> irast.Statement:
pass

Expand All @@ -212,6 +214,7 @@ def compile_ast_to_ir(
*,
script_info: Optional[irast.ScriptInfo] = None,
options: Optional[CompilerOptions] = None,
inlining_context: Optional[context.ContextLevel] = None,
) -> irast.ConfigCommand:
pass

Expand All @@ -223,6 +226,7 @@ def compile_ast_to_ir(
*,
script_info: Optional[irast.ScriptInfo] = None,
options: Optional[CompilerOptions] = None,
inlining_context: Optional[context.ContextLevel] = None,
) -> irast.Statement | irast.ConfigCommand:
pass

Expand All @@ -234,6 +238,7 @@ def compile_ast_to_ir(
*,
script_info: Optional[irast.ScriptInfo] = None,
options: Optional[CompilerOptions] = None,
inlining_context: Optional[context.ContextLevel] = None,
) -> irast.Statement | irast.ConfigCommand:
"""Compile given EdgeQL AST into EdgeDB IR.

Expand Down Expand Up @@ -273,7 +278,11 @@ def compile_ast_to_ir(
debug.header('EdgeQL AST')
debug.dump(tree, schema=schema)

ctx = stmtctx_mod.init_context(schema=schema, options=options)
ctx = stmtctx_mod.init_context(
schema=schema,
options=options,
inlining_context=inlining_context,
)

if isinstance(tree, qlast.Expr) and ctx.implicit_limit:
tree = qlast.SelectQuery(result=tree, implicit=True)
Expand Down Expand Up @@ -317,6 +326,7 @@ def compile_ast_fragment_to_ir(
schema: s_schema.Schema,
*,
options: Optional[CompilerOptions] = None,
inlining_context: Optional[context.ContextLevel] = None,
) -> irast.Statement:
"""Compile given EdgeQL AST fragment into EdgeDB IR.

Expand All @@ -342,7 +352,9 @@ def compile_ast_fragment_to_ir(
if options is None:
options = CompilerOptions()

ctx = stmtctx_mod.init_context(schema=schema, options=options)
ctx = stmtctx_mod.init_context(
schema=schema, options=options, inlining_context=inlining_context
)
ir_set = dispatch_mod.compile(tree, ctx=ctx)

result_type = ctx.env.set_types[ir_set]
Expand Down
202 changes: 193 additions & 9 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@
)

from edb import errors
from edb.common import ast
from edb.common import parsing
from edb.common.typeutils import not_none

from edb.ir import ast as irast
from edb.ir import staeval
from edb.ir import utils as irutils
from edb.ir import typeutils as irtyputils

from edb.schema import constraints as s_constr
from edb.schema import delta as sd
from edb.schema import functions as s_func
from edb.schema import modules as s_mod
from edb.schema import name as sn
Expand Down Expand Up @@ -196,6 +199,26 @@ def compile_FunctionCall(
func = matched_call.func
assert isinstance(func, s_func.Function)

inline_func = None
if (
func.get_language(ctx.env.schema) == qlast.Language.EdgeQL
and func.get_is_inlined(ctx.env.schema)
):
inline_func = s_func.compile_function_inline(
schema=ctx.env.schema,
context=sd.CommandContext(
schema=ctx.env.schema,
),
body=not_none(func.get_nativecode(ctx.env.schema)),
func_name=func.get_name(ctx.env.schema),
params=func.get_params(ctx.env.schema),
language=not_none(func.get_language(ctx.env.schema)),
return_type=func.get_return_type(ctx.env.schema),
return_typemod=func.get_return_typemod(ctx.env.schema),
track_schema_ref_exprs=False,
inlining_context=ctx,
)

# Record this node in the list of potential DML expressions.
if func.get_has_dml(env.schema):
ctx.env.dml_exprs.append(expr)
Expand Down Expand Up @@ -246,7 +269,7 @@ def compile_FunctionCall(

matched_func_initial_value = func.get_initial_value(env.schema)

final_args = finalize_args(
final_args, param_name_to_arg_key = finalize_args(
matched_call,
guessed_typemods=typemods,
is_polymorphic=is_polymorphic,
Expand Down Expand Up @@ -331,11 +354,15 @@ def compile_FunctionCall(
rtype, env=env,
),
typemod=matched_call.func.get_return_typemod(env.schema),
has_empty_variadic=matched_call.has_empty_variadic,
has_empty_variadic=(matched_call.variadic_arg_count == 0),
variadic_param_type=variadic_param_type,
func_initial_value=func_initial_value,
tuple_path_ids=tuple_path_ids,
impl_is_strict=func.get_impl_is_strict(env.schema),
impl_is_strict=(
func.get_impl_is_strict(env.schema)
# Inlined functions should always check for null arguments.
and not inline_func
),
prefer_subquery_args=func.get_prefer_subquery_args(env.schema),
is_singleton_set_of=func.get_is_singleton_set_of(env.schema),
global_args=global_args,
Expand All @@ -345,6 +372,88 @@ def compile_FunctionCall(
# Apply special function handling
if special_func := _SPECIAL_FUNCTIONS.get(str(func_name)):
res = special_func(fcall, ctx=ctx)
elif inline_func:
res = fcall

# TODO: Global parameters still use the implicit globals parameter.
# They should be directly substituted in whenever possible.

inline_args: dict[str, irast.CallArg | irast.Set] = {}

# Collect non-default call args to inline
for param_shortname, arg_key in param_name_to_arg_key.items():
if (
isinstance(arg_key, int)
and matched_call.variadic_arg_id is not None
and arg_key >= matched_call.variadic_arg_id
):
continue

arg = final_args[arg_key]
if arg.is_default:
continue

inline_args[param_shortname] = arg

# Package variadic arguments into an array
if variadic_param is not None:
assert variadic_param_type is not None
assert matched_call.variadic_arg_id is not None
assert matched_call.variadic_arg_count is not None

param_shortname = variadic_param.get_parameter_name(env.schema)
inline_args[param_shortname] = ir_set = setgen.ensure_set(
irast.Array(
elements=[
final_args[arg_key].expr
for arg_key in range(
matched_call.variadic_arg_id,
matched_call.variadic_arg_id
+ matched_call.variadic_arg_count
)
],
typeref=variadic_param_type,
),
ctx=ctx,
)

# Compile default args if necessary
for param in matched_func_params.objects(env.schema):
param_shortname = param.get_parameter_name(env.schema)

if param_shortname in inline_args:
continue

else:
# Missing named only args have their default values already
# compiled in try_bind_call_args.
if bound_args := [
bound_arg
for bound_arg in matched_call.args
if bound_arg.param == param and bound_arg.is_default
]:
assert len(bound_args) == 1
inline_args[param_shortname] = bound_args[0].val
continue

# Check if default is available
p_default = param.get_default(env.schema)
if p_default is None:
continue

# Compile default
assert isinstance(param, s_func.Parameter)
p_ir_default = param.get_ir_default(
schema=env.schema,
inlining_context=ctx,
)
inline_args[param_shortname] = (
p_ir_default.expr
)

argument_inliner = ArgumentInliner(inline_args, ctx=ctx)
res.body = argument_inliner.visit(inline_func)

else:
res = fcall

Expand All @@ -356,14 +465,80 @@ def compile_FunctionCall(
for arg in res.args.values():
pathctx.register_set_in_scope(
arg.expr,
optional=arg.param_typemod == ft.TypeModifier.OptionalType,
optional=(
arg.param_typemod == ft.TypeModifier.OptionalType
),
ctx=ctx,
)

ir_set = setgen.ensure_set(res, typehint=rtype, path_id=path_id, ctx=ctx)
return stmt.maybe_add_view(ir_set, ctx=ctx)


class ArgumentInliner(ast.NodeTransformer):

mapped_args: dict[irast.PathId, irast.PathId]
Comment on lines +473 to +475
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize now that it should be possible (and probably nicer) to put the inlined arguments in the context somehow while compiling the inlined function, and inserting them directly when compiling Parameters, rather then substituting them in after the fact.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to not go down that path since the backend compiler does stuff with args too. Probably possible though.

inlined_arg_keys: list[int | str]

def __init__(
self,
inline_args: dict[str, irast.CallArg | irast.Set],
ctx: context.ContextLevel,
) -> None:
super().__init__()
self.inline_args = inline_args
self.ctx = ctx
self.mapped_args = {}

def visit_Set(self, node: irast.Set) -> irast.Base:
if (
isinstance(node.expr, irast.Parameter)
and node.expr.name in self.inline_args
):
arg = self.inline_args[node.expr.name]
if isinstance(arg, irast.CallArg):
# Inline param as an expr ref. The pg compiler will find the
# appropriate rvar.
self.mapped_args[node.path_id] = arg.expr.path_id
return setgen.ensure_set(
irast.InlinedParameterExpr(
typeref=arg.expr.typeref,
required=node.expr.required,
is_global=node.expr.is_global,
),
path_id=arg.expr.path_id,
ctx=self.ctx,
)
else:
# Directly inline the set.
# Used for default values, which are constants.
return arg

elif isinstance(node.expr, irast.Pointer):
# The set and source path ids must match in order for the pointer
# to find the correct rvar. If a pointer's source path was modified
# because of an inlined parameter, modify the pointer's path as
# well.
prev_source_path_id = node.expr.source.path_id
result = cast(irast.Set, self.generic_visit(node))

if prev_source_path_id in self.mapped_args:
result = setgen.new_set_from_set(
result,
path_id=irtyputils.replace_pathid_prefix(
result.path_id,
prev_source_path_id,
self.mapped_args[prev_source_path_id],
),
ctx=self.ctx,
)
self.mapped_args[node.path_id] = result.path_id

return result

return cast(irast.Base, self.generic_visit(node))


class _SpecialCaseFunc(Protocol):
def __call__(
self, call: irast.FunctionCall, *, ctx: context.ContextLevel
Expand Down Expand Up @@ -618,7 +793,7 @@ def compile_operator(
matched_rtype.is_polymorphic(env.schema)
)

final_args = finalize_args(
final_args, _ = finalize_args(
matched_call,
actual_typemods=actual_typemods,
guessed_typemods=typemods,
Expand Down Expand Up @@ -856,9 +1031,10 @@ def finalize_args(
guessed_typemods: Dict[Union[int, str], ft.TypeModifier],
is_polymorphic: bool = False,
ctx: context.ContextLevel,
) -> Dict[Union[int, str], irast.CallArg]:
) -> tuple[dict[int | str, irast.CallArg], dict[str, int | str]]:

args: Dict[Union[int, str], irast.CallArg] = {}
args: dict[int | str, irast.CallArg] = {}
param_name_to_arg: dict[str, int | str] = {}
position_index: int = 0

for i, barg in enumerate(bound_call.args):
Expand All @@ -867,6 +1043,7 @@ def finalize_args(
arg_type_path_id: Optional[irast.PathId] = None
if param is None:
# defaults bitmask
param_name_to_arg['__defaults_mask__'] = -1
args[-1] = irast.CallArg(
expr=arg_val,
param_typemod=ft.TypeModifier.SingletonType,
Expand Down Expand Up @@ -972,14 +1149,21 @@ def finalize_args(

arg = irast.CallArg(expr=arg_val, expr_type_path_id=arg_type_path_id,
is_default=barg.is_default, param_typemod=param_mod)
param_shortname = param.get_parameter_name(ctx.env.schema)
if param_kind is ft.ParameterKind.NamedOnlyParam:
param_shortname = param.get_parameter_name(ctx.env.schema)
args[param_shortname] = arg
param_name_to_arg[param_shortname] = param_shortname
else:
args[position_index] = arg
if (
# Variadic args will all have the same name, but different
# indexes. We want to take the first index.
param_shortname not in param_name_to_arg
):
param_name_to_arg[param_shortname] = position_index
position_index += 1

return args
return args, param_name_to_arg


@_special_case('ext::ai::search')
Expand Down
10 changes: 10 additions & 0 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,16 @@ def __infer_param(
return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_inlined_param(
ir: irast.InlinedParameterExpr,
*,
scope_tree: irast.ScopeTreeNode,
ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_const_set(
ir: irast.ConstantSet,
Expand Down
Loading