Skip to content

Commit

Permalink
Version calls to stdlib sql functions (#7528)
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan authored Jul 3, 2024
1 parent 6acf5d5 commit d14e0aa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 25 deletions.
6 changes: 6 additions & 0 deletions edb/pgsql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def versioned_name(
return s


def maybe_versioned_name(
s: tuple[str, ...], *, versioned: bool,
) -> tuple[str, ...]:
return versioned_name(s) if versioned else s


@functools.lru_cache()
def _edgedb_name_to_pg_name(name: str, prefix_length: int = 0) -> str:
# Note: PostgreSQL doesn't have a sha1 implementation as a
Expand Down
42 changes: 35 additions & 7 deletions edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,19 @@ def compile_operator(

else:
if expr.sql_function:
sql_func = expr.sql_function[0]
func_name = tuple(sql_func.split('.', 1))
if len(expr.sql_function) > 1:
sql_func, *cast_types = expr.sql_function

func_name = common.maybe_versioned_name(
tuple(sql_func.split('.', 1)),
versioned=(
ctx.env.versioned_stdlib
and expr.func_shortname.get_root_module_name().name != 'ext'
),
)

if cast_types:
# Explicit operand types given in FROM SQL FUNCTION
lexpr, rexpr = _cast_operands(
lexpr, rexpr, expr.sql_function[1:])
lexpr, rexpr = _cast_operands(lexpr, rexpr, cast_types)
else:
func_name = common.get_operator_backend_name(
expr.func_shortname, aspect='function',
Expand Down Expand Up @@ -541,7 +548,7 @@ def compile_operator(
def _cast_operands(
lexpr: Optional[pgast.BaseExpr],
rexpr: Optional[pgast.BaseExpr],
sql_types: Tuple[str, ...],
sql_types: Sequence[str],
) -> Tuple[Optional[pgast.BaseExpr], Optional[pgast.BaseExpr]]:

if lexpr is not None:
Expand Down Expand Up @@ -584,6 +591,27 @@ def _cast_operands(
return lexpr, rexpr


def get_func_call_backend_name(
expr: irast.FunctionCall, *,
ctx: context.CompilerContextLevel
) -> Tuple[str, ...]:
if expr.func_sql_function:
# The name might contain a "." if it's one of our
# metaschema helpers.
func_name = common.maybe_versioned_name(
tuple(expr.func_sql_function.split('.', 1)),
versioned=(
ctx.env.versioned_stdlib
and expr.func_shortname.get_root_module_name().name != 'ext'
),
)
else:
func_name = common.get_function_backend_name(
expr.func_shortname, expr.backend_name,
versioned=ctx.env.versioned_stdlib)
return func_name


@dispatch.compile.register(irast.TypeCheckOp)
def compile_TypeCheckOp(
expr: irast.TypeCheckOp, *,
Expand Down Expand Up @@ -717,7 +745,7 @@ def compile_FunctionCall(

args.append(pgast.VariadicArgument(expr=var))

name = relgen.get_func_call_backend_name(expr, ctx=ctx)
name = get_func_call_backend_name(expr, ctx=ctx)

result: pgast.BaseExpr = pgast.FuncCall(name=name, args=args)

Expand Down
21 changes: 3 additions & 18 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3304,21 +3304,6 @@ def _compile_call_args(
return args


def get_func_call_backend_name(
expr: irast.FunctionCall, *,
ctx: context.CompilerContextLevel) -> Tuple[str, ...]:
if expr.func_sql_function:
# The name might contain a "." if it's one of our
# metaschema helpers.
# XXX: VERSIONING?
func_name = tuple(expr.func_sql_function.split('.', 1))
else:
func_name = common.get_function_backend_name(
expr.func_shortname, expr.backend_name,
versioned=ctx.env.versioned_stdlib)
return func_name


def process_set_as_func_enumerate(
ir_set: irast.Set, *, ctx: context.CompilerContextLevel
) -> SetRVars:
Expand All @@ -3333,7 +3318,7 @@ def process_set_as_func_enumerate(
with newctx.new() as newctx2:
newctx2.expr_exposed = False
args = _compile_call_args(inner_func_set, ctx=newctx2)
func_name = get_func_call_backend_name(inner_func, ctx=newctx)
func_name = exprcomp.get_func_call_backend_name(inner_func, ctx=newctx)

set_expr = _process_set_func_with_ordinality(
ir_set=inner_func_set,
Expand Down Expand Up @@ -3362,7 +3347,7 @@ def process_set_as_func_expr(
if expr.body is not None:
set_expr = dispatch.compile(expr.body, ctx=newctx)
else:
name = get_func_call_backend_name(expr, ctx=newctx)
name = exprcomp.get_func_call_backend_name(expr, ctx=newctx)

if expr.typemod is qltypes.TypeModifier.SetOfType:
set_expr = _process_set_func(
Expand Down Expand Up @@ -3539,7 +3524,7 @@ def process_set_as_agg_expr_inner(

args.append(arg_ref)

name = get_func_call_backend_name(expr, ctx=newctx)
name = exprcomp.get_func_call_backend_name(expr, ctx=newctx)

set_expr = pgast.FuncCall(
name=name, args=args, agg_order=agg_sort, agg_filter=agg_filter,
Expand Down

0 comments on commit d14e0aa

Please sign in to comment.