diff --git a/edb/pgsql/common.py b/edb/pgsql/common.py index 38d1b88fe3e..9e991297318 100644 --- a/edb/pgsql/common.py +++ b/edb/pgsql/common.py @@ -192,6 +192,12 @@ def versioned_name( return s +def maybe_versioned_name( + s: tuple[str, ...], *, versioned: bool, +) -> tuple[str, ...]: + return versioned_name(s) if versioned else s + + @functools.lru_cache() def _edgedb_name_to_pg_name(name: str, prefix_length: int = 0) -> str: # Note: PostgreSQL doesn't have a sha1 implementation as a diff --git a/edb/pgsql/compiler/expr.py b/edb/pgsql/compiler/expr.py index d16e7a859c0..3525c2d0855 100644 --- a/edb/pgsql/compiler/expr.py +++ b/edb/pgsql/compiler/expr.py @@ -497,12 +497,19 @@ def compile_operator( else: if expr.sql_function: - sql_func = expr.sql_function[0] - func_name = tuple(sql_func.split('.', 1)) - if len(expr.sql_function) > 1: + sql_func, *cast_types = expr.sql_function + + func_name = common.maybe_versioned_name( + tuple(sql_func.split('.', 1)), + versioned=( + ctx.env.versioned_stdlib + and expr.func_shortname.get_root_module_name().name != 'ext' + ), + ) + + if cast_types: # Explicit operand types given in FROM SQL FUNCTION - lexpr, rexpr = _cast_operands( - lexpr, rexpr, expr.sql_function[1:]) + lexpr, rexpr = _cast_operands(lexpr, rexpr, cast_types) else: func_name = common.get_operator_backend_name( expr.func_shortname, aspect='function', @@ -541,7 +548,7 @@ def compile_operator( def _cast_operands( lexpr: Optional[pgast.BaseExpr], rexpr: Optional[pgast.BaseExpr], - sql_types: Tuple[str, ...], + sql_types: Sequence[str], ) -> Tuple[Optional[pgast.BaseExpr], Optional[pgast.BaseExpr]]: if lexpr is not None: @@ -584,6 +591,27 @@ def _cast_operands( return lexpr, rexpr +def get_func_call_backend_name( + expr: irast.FunctionCall, *, + ctx: context.CompilerContextLevel +) -> Tuple[str, ...]: + if expr.func_sql_function: + # The name might contain a "." if it's one of our + # metaschema helpers. + func_name = common.maybe_versioned_name( + tuple(expr.func_sql_function.split('.', 1)), + versioned=( + ctx.env.versioned_stdlib + and expr.func_shortname.get_root_module_name().name != 'ext' + ), + ) + else: + func_name = common.get_function_backend_name( + expr.func_shortname, expr.backend_name, + versioned=ctx.env.versioned_stdlib) + return func_name + + @dispatch.compile.register(irast.TypeCheckOp) def compile_TypeCheckOp( expr: irast.TypeCheckOp, *, @@ -717,7 +745,7 @@ def compile_FunctionCall( args.append(pgast.VariadicArgument(expr=var)) - name = relgen.get_func_call_backend_name(expr, ctx=ctx) + name = get_func_call_backend_name(expr, ctx=ctx) result: pgast.BaseExpr = pgast.FuncCall(name=name, args=args) diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index c3ccee95ec5..8033d8c8438 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -3304,21 +3304,6 @@ def _compile_call_args( return args -def get_func_call_backend_name( - expr: irast.FunctionCall, *, - ctx: context.CompilerContextLevel) -> Tuple[str, ...]: - if expr.func_sql_function: - # The name might contain a "." if it's one of our - # metaschema helpers. - # XXX: VERSIONING? - func_name = tuple(expr.func_sql_function.split('.', 1)) - else: - func_name = common.get_function_backend_name( - expr.func_shortname, expr.backend_name, - versioned=ctx.env.versioned_stdlib) - return func_name - - def process_set_as_func_enumerate( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: @@ -3333,7 +3318,7 @@ def process_set_as_func_enumerate( with newctx.new() as newctx2: newctx2.expr_exposed = False args = _compile_call_args(inner_func_set, ctx=newctx2) - func_name = get_func_call_backend_name(inner_func, ctx=newctx) + func_name = exprcomp.get_func_call_backend_name(inner_func, ctx=newctx) set_expr = _process_set_func_with_ordinality( ir_set=inner_func_set, @@ -3362,7 +3347,7 @@ def process_set_as_func_expr( if expr.body is not None: set_expr = dispatch.compile(expr.body, ctx=newctx) else: - name = get_func_call_backend_name(expr, ctx=newctx) + name = exprcomp.get_func_call_backend_name(expr, ctx=newctx) if expr.typemod is qltypes.TypeModifier.SetOfType: set_expr = _process_set_func( @@ -3539,7 +3524,7 @@ def process_set_as_agg_expr_inner( args.append(arg_ref) - name = get_func_call_backend_name(expr, ctx=newctx) + name = exprcomp.get_func_call_backend_name(expr, ctx=newctx) set_expr = pgast.FuncCall( name=name, args=args, agg_order=agg_sort, agg_filter=agg_filter,