Skip to content

Commit

Permalink
Add inlining functions with parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Sep 4, 2024
1 parent d7c408f commit fe66209
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 13 deletions.
70 changes: 64 additions & 6 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)

from edb import errors
from edb.common import ast
from edb.common import parsing
from edb.common.typeutils import not_none

Expand All @@ -45,6 +46,7 @@
from edb.ir import utils as irutils

from edb.schema import constraints as s_constr
from edb.schema import delta as sd
from edb.schema import functions as s_func
from edb.schema import modules as s_mod
from edb.schema import name as sn
Expand Down Expand Up @@ -196,6 +198,32 @@ def compile_FunctionCall(
func = matched_call.func
assert isinstance(func, s_func.Function)

inline_func = None
if (
func.get_language(ctx.env.schema) == qlast.Language.EdgeQL
and func.get_shortname(ctx.env.schema) in [
sn.QualName('default', 'one'),
sn.QualName('default', 'id'),
sn.QualName('default', 'inc'),
]
):
inline_func = s_func.compile_function(
schema=ctx.env.schema,
context=sd.CommandContext(
# Probably not correct. Need to store modaliases while compiling
# functions?
modaliases=ctx.modaliases,
schema=ctx.env.schema,
),
body=func.get_nativecode(ctx.env.schema),
func_name=func.get_name(ctx.env.schema),
params=func.get_params(ctx.env.schema),
language=func.get_language(ctx.env.schema),
return_type=func.get_return_type(ctx.env.schema),
return_typemod=func.get_return_typemod(ctx.env.schema),
track_schema_ref_exprs=False,
)

# Record this node in the list of potential DML expressions.
if func.get_has_dml(env.schema):
ctx.env.dml_exprs.append(expr)
Expand Down Expand Up @@ -246,7 +274,7 @@ def compile_FunctionCall(

matched_func_initial_value = func.get_initial_value(env.schema)

final_args = finalize_args(
final_args, param_name_to_arg = finalize_args(
matched_call,
guessed_typemods=typemods,
is_polymorphic=is_polymorphic,
Expand Down Expand Up @@ -345,6 +373,12 @@ def compile_FunctionCall(
# Apply special function handling
if special_func := _SPECIAL_FUNCTIONS.get(str(func_name)):
res = special_func(fcall, ctx=ctx)
elif inline_func:
res = fcall
inline_func_expr = inline_func.irast.expr.expr
res.body = ArgumentInliner(param_name_to_arg, final_args).visit(
inline_func_expr
)
else:
res = fcall

Expand All @@ -364,6 +398,27 @@ def compile_FunctionCall(
return stmt.maybe_add_view(ir_set, ctx=ctx)


class ArgumentInliner(ast.NodeTransformer):

def __init__(
self,
param_name_to_arg: dict[str, int | str],
final_args: dict[int | str, irast.CallArg],
) -> None:
super().__init__()
self.param_name_to_arg = param_name_to_arg
self.final_args = final_args

def visit_Parameter(self, node: irast.Parameter) -> irast.Base:
if node.name in self.param_name_to_arg:
return irast.InlinedParameter(
name=self.param_name_to_arg[node.name],
required=node.required,
is_global=node.is_global,
)
return node


class _SpecialCaseFunc(Protocol):
def __call__(
self, call: irast.FunctionCall, *, ctx: context.ContextLevel
Expand Down Expand Up @@ -618,7 +673,7 @@ def compile_operator(
matched_rtype.is_polymorphic(env.schema)
)

final_args = finalize_args(
final_args, _ = finalize_args(
matched_call,
actual_typemods=actual_typemods,
guessed_typemods=typemods,
Expand Down Expand Up @@ -856,9 +911,10 @@ def finalize_args(
guessed_typemods: Dict[Union[int, str], ft.TypeModifier],
is_polymorphic: bool = False,
ctx: context.ContextLevel,
) -> Dict[Union[int, str], irast.CallArg]:
) -> tuple[dict[int | str, irast.CallArg], dict[str, int | str]]:

args: Dict[Union[int, str], irast.CallArg] = {}
args: dict[int | str, irast.CallArg] = {}
param_name_to_arg: dict[str, int | str] = {}
position_index: int = 0

for i, barg in enumerate(bound_call.args):
Expand Down Expand Up @@ -972,14 +1028,16 @@ def finalize_args(

arg = irast.CallArg(expr=arg_val, expr_type_path_id=arg_type_path_id,
is_default=barg.is_default, param_typemod=param_mod)
param_shortname = param.get_parameter_name(ctx.env.schema)
if param_kind is ft.ParameterKind.NamedOnlyParam:
param_shortname = param.get_parameter_name(ctx.env.schema)
args[param_shortname] = arg
param_name_to_arg[param_shortname] = param_shortname
else:
args[position_index] = arg
param_name_to_arg[param_shortname] = position_index
position_index += 1

return args
return args, param_name_to_arg


@_special_case('ext::ai::search')
Expand Down
10 changes: 10 additions & 0 deletions edb/edgeql/compiler/inference/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,16 @@ def __infer_param(
return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_inlined_param(
ir: irast.InlinedParameter,
*,
scope_tree: irast.ScopeTreeNode,
ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_const_set(
ir: irast.ConstantSet,
Expand Down
10 changes: 10 additions & 0 deletions edb/edgeql/compiler/inference/multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,16 @@ def __infer_param(
return UNIQUE


@_infer_multiplicity.register
def __infer_inlined_param(
ir: irast.InlinedParameter,
*,
scope_tree: irast.ScopeTreeNode,
ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
return UNIQUE


@_infer_multiplicity.register
def __infer_const_set(
ir: irast.ConstantSet,
Expand Down
8 changes: 8 additions & 0 deletions edb/edgeql/compiler/inference/volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ def __infer_param(
return STABLE if ir.is_global else IMMUTABLE


@_infer_volatility_inner.register
def __infer_inlined_param(
ir: irast.InlinedParameter,
env: context.Environment,
) -> InferredVolatility:
return STABLE if ir.is_global else IMMUTABLE


@_infer_volatility_inner.register
def __infer_const_set(
ir: irast.ConstantSet,
Expand Down
8 changes: 8 additions & 0 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,14 @@ def is_global(self) -> bool:
return self.is_implicit_global is not None


class InlinedParameter(ImmutableBase):

# int for positional argument, str for named argument
name: int | str
required: bool
is_global: bool


class TupleElement(ImmutableBase):

name: str
Expand Down
8 changes: 8 additions & 0 deletions edb/pgsql/compiler/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ class CompilerContextLevel(compiler.ContextLevel):
#: needed by DML.
shapes_needed_by_dml: Set[irast.Set]

#: When a function has an inlined body, the compiled parameters are
#: subsituted in during pg compilation.
inlined_params: dict[int | str, pgast.BaseExpr]

def __init__(
self,
prevlevel: Optional[CompilerContextLevel],
Expand Down Expand Up @@ -388,6 +392,8 @@ def __init__(

self.trigger_mode = False

self.inlined_params = {}

else:
self.env = prevlevel.env
self.argmap = prevlevel.argmap
Expand Down Expand Up @@ -429,6 +435,8 @@ def __init__(

self.trigger_mode = prevlevel.trigger_mode

self.inlined_params = prevlevel.inlined_params

if mode is ContextSwitchMode.SUBSTMT:
if self.pending_query is not None:
self.rel = self.pending_query
Expand Down
8 changes: 8 additions & 0 deletions edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def _compile_set_impl(
and not ctx.env.ignore_object_shapes):
_compile_shape(ir_set, ir_set.shape, ctx=ctx)

elif isinstance(ir_set.expr, irast.InlinedParameter):
# InlinedParameter will already be compiled in process_set_as_func_expr,
# Just place the path value here.
value = ctx.inlined_params[ir_set.expr.name]
pathctx.put_path_value_var_if_not_exists(
ctx.rel, ir_set.path_id, value
)

elif ir_set.path_scope_id is not None and not is_toplevel:
# This Set is behind a scope fence, so compute it
# in a fenced context.
Expand Down
22 changes: 15 additions & 7 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2268,7 +2268,7 @@ def process_set_as_oper_expr(
# XXX: do we need a subrel?
with ctx.new() as newctx:
newctx.expr_exposed = False
args = _compile_call_args(ir_set, ctx=newctx)
args, _ = _compile_call_args(ir_set, ctx=newctx)
oper_expr = exprcomp.compile_operator(ir_set.expr, args, ctx=newctx)

pathctx.put_path_value_var_if_not_exists(
Expand Down Expand Up @@ -3355,7 +3355,7 @@ def _compile_call_args(
skip: Collection[int] = (),
no_subquery_args: bool = False,
ctx: context.CompilerContextLevel,
) -> List[pgast.BaseExpr]:
) -> tuple[list[pgast.BaseExpr], dict[int | str, int]]:
"""
Compiles function call arguments, whose index is not in `skip`.
"""
Expand All @@ -3364,6 +3364,7 @@ def _compile_call_args(
assert isinstance(expr, irast.Call)

args = []
arg_indexes = {}

if isinstance(expr, irast.FunctionCall) and expr.global_args:
for glob_arg in expr.global_args:
Expand Down Expand Up @@ -3399,6 +3400,7 @@ def _compile_call_args(
else:
arg_ref = dispatch.compile(ir_arg.expr, ctx=ctx)
arg_ref = output.output_as_value(arg_ref, env=ctx.env)
arg_indexes[ir_key] = len(args)
args.append(arg_ref)
_compile_arg_null_check(expr, ir_arg, arg_ref, typemod, ctx=ctx)

Expand Down Expand Up @@ -3434,7 +3436,7 @@ def _compile_call_args(

args.append(pgast.VariadicArgument(expr=var))

return args
return args, arg_indexes


def process_set_as_func_enumerate(
Expand All @@ -3450,7 +3452,7 @@ def process_set_as_func_enumerate(
with ctx.subrel() as newctx:
with newctx.new() as newctx2:
newctx2.expr_exposed = False
args = _compile_call_args(inner_func_set, ctx=newctx2)
args, _ = _compile_call_args(inner_func_set, ctx=newctx2)
func_name = exprcomp.get_func_call_backend_name(inner_func, ctx=newctx)

set_expr = _process_set_func_with_ordinality(
Expand All @@ -3475,10 +3477,16 @@ def process_set_as_func_expr(

with ctx.subrel() as newctx:
newctx.expr_exposed = False
args = _compile_call_args(ir_set, ctx=newctx)
args, arg_indexes = _compile_call_args(ir_set, ctx=newctx)

if expr.body is not None:
set_expr = dispatch.compile(expr.body, ctx=newctx)
with newctx.subrel() as inlined_ctx:
inlined_ctx.inlined_params = {
name: args[index]
for name, index in arg_indexes.items()
}
set_expr = dispatch.compile(expr.body, ctx=inlined_ctx)

else:
name = exprcomp.get_func_call_backend_name(expr, ctx=newctx)

Expand Down Expand Up @@ -4148,7 +4156,7 @@ def _process_set_as_object_search(
# Also, disable subquery args. ai::search needs it for its
# scoping effects, but we don't need to use it here, since
# it can cause the ai search to duplicate arguments.
args_pg = _compile_call_args(
args_pg, _ = _compile_call_args(
ir_set, skip={0}, no_subquery_args=True, ctx=ctx)

with ctx.subrel() as newctx:
Expand Down

0 comments on commit fe66209

Please sign in to comment.