From f3074a595b9fbd2373d2d3264eb83703a4ac5e55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 16 Oct 2024 15:22:09 +0200 Subject: [PATCH 1/2] don't reorder SQL params --- edb/pgsql/resolver/__init__.py | 1 + edb/pgsql/resolver/command.py | 20 +++++++++++++- edb/pgsql/resolver/expr.py | 17 ++---------- edb/server/compiler/dbstate.py | 4 +-- edb/server/pgcon/pgcon.pyx | 20 ++++++-------- edb/server/protocol/pg_ext.pyx | 48 ++++++++++++++++++---------------- 6 files changed, 58 insertions(+), 52 deletions(-) diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index 17ff75bffaf..ddf37165c15 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -67,6 +67,7 @@ def resolve( _ = context.ResolverContext(initial=ctx) + command.init_external_params(query, ctx) top_level_ctes = command.compile_dml(query, ctx=ctx) resolved = dispatch.resolve(query, ctx=ctx) diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py index dec4cd7bcbc..e4fe78f5c89 100644 --- a/edb/pgsql/resolver/command.py +++ b/edb/pgsql/resolver/command.py @@ -727,7 +727,8 @@ def _uncompile_insert_pointer_stmt( sub_name = sub.get_shortname(ctx.schema) target_ql: qlast.Expr = qlast.Path( - steps=[value_ql, qlast.Ptr(name='__target__')]) + steps=[value_ql, qlast.Ptr(name='__target__')] + ) if isinstance(sub_target, s_objtypes.ObjectType): assert isinstance(target_ql, qlast.Path) @@ -2093,3 +2094,20 @@ def __init__(self, mapping: Dict[int, int]) -> None: def visit_Param(self, p: pgast.Param) -> None: p.index = self.mapping[p.index] + + +def init_external_params(query: pgast.Base, ctx: Context): + counter = ParamCounter() + counter.node_visit(query) + for _ in range(counter.param_count): + ctx.query_params.append(dbstate.SQLParamExternal()) + + +class ParamCounter(ast.NodeVisitor): + def __init__(self) -> None: + super().__init__() + self.param_count = 0 + + def visit_ParamRef(self, p: pgast.ParamRef) -> None: + if self.param_count < p.number: + self.param_count = p.number diff --git a/edb/pgsql/resolver/expr.py b/edb/pgsql/resolver/expr.py index 08275b64fcb..ccca55c479f 100644 --- a/edb/pgsql/resolver/expr.py +++ b/edb/pgsql/resolver/expr.py @@ -29,7 +29,6 @@ from edb.pgsql import common from edb.pgsql import compiler as pgcompiler from edb.pgsql.compiler import enums as pgce -from edb.server.compiler import dbstate from edb.schema import types as s_types @@ -513,20 +512,8 @@ def resolve_ParamRef( *, ctx: Context, ) -> pgast.ParamRef: - internal_index: Optional[int] = None - param: Optional[dbstate.SQLParam] = None - for i, p in enumerate(ctx.query_params): - if isinstance(p, dbstate.SQLParamExternal) and p.index == expr.number: - param = p - internal_index = i + 1 - break - if not param: - param = dbstate.SQLParamExternal(index=expr.number) - internal_index = len(ctx.query_params) + 1 - ctx.query_params.append(param) - assert internal_index - - return pgast.ParamRef(number=internal_index) + # external params map one-to-one to internal params + return expr @dispatch._resolve.register diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index f801f882a71..68e38e9b1b1 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -600,8 +600,8 @@ class SQLParamExternal(SQLParam): # An internal query param whose value is provided by an external param. # So a user-visible param. - # External index - index: int + # External params share the index with internal params + pass @dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False) diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 18e08fec2af..e01d5eae299 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1782,23 +1782,19 @@ cdef class PGConnection: # include the internal params for globals. # This chunk of code remaps the descriptions of internal # params into external ones. - count_internal = self.buffer.read_int16() + self.buffer.read_int16() # count_internal data_internal = self.buffer.consume_message() msg_buf = WriteBuffer.new_message(b't') - external_params = [] + external_params: int64_t = 0 if action.query_unit.params: - for i_int, param in enumerate(action.query_unit.params): - if isinstance(param, dbstate.SQLParamExternal): - i_ext = param.index - 1 - external_params.append((i_ext, i_int)) + for index, param in enumerate(action.query_unit.params): + external_params = index + 1 + if not isinstance(param, dbstate.SQLParamExternal): + break - msg_buf.write_int16(len(external_params)) - - external_params.sort() - for _, i_int in external_params: - oid = data_internal[i_int * 4:(i_int + 1) * 4] - msg_buf.write_bytes(oid) + msg_buf.write_int16(external_params) + msg_buf.write_bytes(data_internal[0:external_params * 4]) buf.write_buffer(msg_buf.end_message()) diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index 3a8362d6560..e1614f7f8fb 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -1585,7 +1585,8 @@ cdef WriteBuffer remap_arguments( cdef: int16_t param_format_count int32_t offset - int16_t max_external_used + int32_t arg_offset_external + int16_t param_count_external int32_t size # The "external" parameters (that are visible to the user) @@ -1621,30 +1622,40 @@ cdef WriteBuffer remap_arguments( offset += param_format_count * 2 # find positions of external args - param_count_external = read_int16(data[offset:offset+2]) + arg_count_external = read_int16(data[offset:offset+2]) offset += 2 - param_pos_external = [] - for p in range(param_count_external): + arg_offset_external = offset + for p in range(arg_count_external): size = read_int32(data[offset:offset+4]) if size == -1: # special case: NULL size = 0 size += 4 # for size which is int32 - param_pos_external.append((offset, size)) offset += size # write remapped args - max_external_used = 0 if params: buf.write_int16(len(params)) - for param in params: - if isinstance(param, dbstate.SQLParamExternal): - # map external arg to internal - o, s = param_pos_external[param.index - 1] - buf.write_bytes(data[o:o+s]) - if max_external_used < param.index: - max_external_used = param.index - elif isinstance(param, dbstate.SQLParamGlobal): + param_count_external = 0 + for i, param in enumerate(params): + param_count_external = i + 1 + if not isinstance(param, dbstate.SQLParamExternal): + break + if param_count_external != arg_count_external: + raise pgerror.new( + pgerror.ERROR_PROTOCOL_VIOLATION, + f'bind message supplies {arg_count_external} ' + f'parameters, but prepared statement "" requires ' + f'{param_count_external}', + ) + + # write external args + if arg_offset_external < offset: + buf.write_bytes(data[arg_offset_external:offset]) + + # write global's args + for param in params[param_count_external:]: + if isinstance(param, dbstate.SQLParamGlobal): name = param.global_name setting_name = f'global {name.module}::{name.name}' values = fe_settings.get(setting_name, None) @@ -1656,14 +1667,7 @@ cdef WriteBuffer remap_arguments( else: buf.write_int16(0) - if max_external_used != param_count_external: - raise pgerror.new( - pgerror.ERROR_PROTOCOL_VIOLATION, - f'bind message supplies {param_count_external} ' - f'parameters, but prepared statement "" requires ' - f'{max_external_used}', - ) - + # result format codes buf.write_bytes(data[offset:]) return buf From 9828e5599337c403cef218416db6c508a7caae79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 16 Oct 2024 19:01:58 +0200 Subject: [PATCH 2/2] fix --- edb/server/pgcon/pgcon.pyx | 2 +- edb/server/protocol/pg_ext.pyx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index e01d5eae299..6e5c17fe27c 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1789,9 +1789,9 @@ cdef class PGConnection: external_params: int64_t = 0 if action.query_unit.params: for index, param in enumerate(action.query_unit.params): - external_params = index + 1 if not isinstance(param, dbstate.SQLParamExternal): break + external_params = index + 1 msg_buf.write_int16(external_params) msg_buf.write_bytes(data_internal[0:external_params * 4]) diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index e1614f7f8fb..7b4c6b1e50b 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -1638,9 +1638,9 @@ cdef WriteBuffer remap_arguments( param_count_external = 0 for i, param in enumerate(params): - param_count_external = i + 1 if not isinstance(param, dbstate.SQLParamExternal): break + param_count_external = i + 1 if param_count_external != arg_count_external: raise pgerror.new( pgerror.ERROR_PROTOCOL_VIOLATION,