diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index 17ff75bffaf..ddf37165c15 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -67,6 +67,7 @@ def resolve( _ = context.ResolverContext(initial=ctx) + command.init_external_params(query, ctx) top_level_ctes = command.compile_dml(query, ctx=ctx) resolved = dispatch.resolve(query, ctx=ctx) diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py index 89ab5a50313..1e9ad54c2e7 100644 --- a/edb/pgsql/resolver/command.py +++ b/edb/pgsql/resolver/command.py @@ -727,7 +727,8 @@ def _uncompile_insert_pointer_stmt( sub_name = sub.get_shortname(ctx.schema) target_ql: qlast.Expr = qlast.Path( - steps=[value_ql, qlast.Ptr(name='__target__')]) + steps=[value_ql, qlast.Ptr(name='__target__')] + ) if isinstance(sub_target, s_objtypes.ObjectType): assert isinstance(target_ql, qlast.Path) @@ -2094,3 +2095,20 @@ def __init__(self, mapping: Dict[int, int]) -> None: def visit_Param(self, p: pgast.Param) -> None: p.index = self.mapping[p.index] + + +def init_external_params(query: pgast.Base, ctx: Context): + counter = ParamCounter() + counter.node_visit(query) + for _ in range(counter.param_count): + ctx.query_params.append(dbstate.SQLParamExternal()) + + +class ParamCounter(ast.NodeVisitor): + def __init__(self) -> None: + super().__init__() + self.param_count = 0 + + def visit_ParamRef(self, p: pgast.ParamRef) -> None: + if self.param_count < p.number: + self.param_count = p.number diff --git a/edb/pgsql/resolver/expr.py b/edb/pgsql/resolver/expr.py index 51a8638632c..264caf86cbd 100644 --- a/edb/pgsql/resolver/expr.py +++ b/edb/pgsql/resolver/expr.py @@ -29,7 +29,6 @@ from edb.pgsql import common from edb.pgsql import compiler as pgcompiler from edb.pgsql.compiler import enums as pgce -from edb.server.compiler import dbstate from edb.schema import types as s_types @@ -515,20 +514,8 @@ def resolve_ParamRef( *, ctx: Context, ) -> pgast.ParamRef: - internal_index: Optional[int] = None - param: Optional[dbstate.SQLParam] = None - for i, p in enumerate(ctx.query_params): - if isinstance(p, dbstate.SQLParamExternal) and p.index == expr.number: - param = p - internal_index = i + 1 - break - if not param: - param = dbstate.SQLParamExternal(index=expr.number) - internal_index = len(ctx.query_params) + 1 - ctx.query_params.append(param) - assert internal_index - - return pgast.ParamRef(number=internal_index) + # external params map one-to-one to internal params + return expr @dispatch._resolve.register diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 0997f83731d..8b3cebb5e48 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -600,8 +600,8 @@ class SQLParamExternal(SQLParam): # An internal query param whose value is provided by an external param. # So a user-visible param. - # External index - index: int + # External params share the index with internal params + pass @dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False) diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 169116e38c8..31f9dd62664 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1782,23 +1782,19 @@ cdef class PGConnection: # include the internal params for globals. # This chunk of code remaps the descriptions of internal # params into external ones. - count_internal = self.buffer.read_int16() + self.buffer.read_int16() # count_internal data_internal = self.buffer.consume_message() msg_buf = WriteBuffer.new_message(b't') - external_params = [] + external_params: int64_t = 0 if action.query_unit.params: - for i_int, param in enumerate(action.query_unit.params): - if isinstance(param, dbstate.SQLParamExternal): - i_ext = param.index - 1 - external_params.append((i_ext, i_int)) + for index, param in enumerate(action.query_unit.params): + if not isinstance(param, dbstate.SQLParamExternal): + break + external_params = index + 1 - msg_buf.write_int16(len(external_params)) - - external_params.sort() - for _, i_int in external_params: - oid = data_internal[i_int * 4:(i_int + 1) * 4] - msg_buf.write_bytes(oid) + msg_buf.write_int16(external_params) + msg_buf.write_bytes(data_internal[0:external_params * 4]) buf.write_buffer(msg_buf.end_message()) diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index 3a8362d6560..7b4c6b1e50b 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -1585,7 +1585,8 @@ cdef WriteBuffer remap_arguments( cdef: int16_t param_format_count int32_t offset - int16_t max_external_used + int32_t arg_offset_external + int16_t param_count_external int32_t size # The "external" parameters (that are visible to the user) @@ -1621,30 +1622,40 @@ cdef WriteBuffer remap_arguments( offset += param_format_count * 2 # find positions of external args - param_count_external = read_int16(data[offset:offset+2]) + arg_count_external = read_int16(data[offset:offset+2]) offset += 2 - param_pos_external = [] - for p in range(param_count_external): + arg_offset_external = offset + for p in range(arg_count_external): size = read_int32(data[offset:offset+4]) if size == -1: # special case: NULL size = 0 size += 4 # for size which is int32 - param_pos_external.append((offset, size)) offset += size # write remapped args - max_external_used = 0 if params: buf.write_int16(len(params)) - for param in params: - if isinstance(param, dbstate.SQLParamExternal): - # map external arg to internal - o, s = param_pos_external[param.index - 1] - buf.write_bytes(data[o:o+s]) - if max_external_used < param.index: - max_external_used = param.index - elif isinstance(param, dbstate.SQLParamGlobal): + param_count_external = 0 + for i, param in enumerate(params): + if not isinstance(param, dbstate.SQLParamExternal): + break + param_count_external = i + 1 + if param_count_external != arg_count_external: + raise pgerror.new( + pgerror.ERROR_PROTOCOL_VIOLATION, + f'bind message supplies {arg_count_external} ' + f'parameters, but prepared statement "" requires ' + f'{param_count_external}', + ) + + # write external args + if arg_offset_external < offset: + buf.write_bytes(data[arg_offset_external:offset]) + + # write global's args + for param in params[param_count_external:]: + if isinstance(param, dbstate.SQLParamGlobal): name = param.global_name setting_name = f'global {name.module}::{name.name}' values = fe_settings.get(setting_name, None) @@ -1656,14 +1667,7 @@ cdef WriteBuffer remap_arguments( else: buf.write_int16(0) - if max_external_used != param_count_external: - raise pgerror.new( - pgerror.ERROR_PROTOCOL_VIOLATION, - f'bind message supplies {param_count_external} ' - f'parameters, but prepared statement "" requires ' - f'{max_external_used}', - ) - + # result format codes buf.write_bytes(data[offset:]) return buf