Skip to content

Commit

Permalink
Persist SQL param indexes over SQL adapter (#7867)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Oct 17, 2024
1 parent 68a32d3 commit 22cab5c
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 52 deletions.
1 change: 1 addition & 0 deletions edb/pgsql/resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def resolve(

_ = context.ResolverContext(initial=ctx)

command.init_external_params(query, ctx)
top_level_ctes = command.compile_dml(query, ctx=ctx)

resolved = dispatch.resolve(query, ctx=ctx)
Expand Down
20 changes: 19 additions & 1 deletion edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,8 @@ def _uncompile_insert_pointer_stmt(
sub_name = sub.get_shortname(ctx.schema)

target_ql: qlast.Expr = qlast.Path(
steps=[value_ql, qlast.Ptr(name='__target__')])
steps=[value_ql, qlast.Ptr(name='__target__')]
)

if isinstance(sub_target, s_objtypes.ObjectType):
assert isinstance(target_ql, qlast.Path)
Expand Down Expand Up @@ -2094,3 +2095,20 @@ def __init__(self, mapping: Dict[int, int]) -> None:

def visit_Param(self, p: pgast.Param) -> None:
p.index = self.mapping[p.index]


def init_external_params(query: pgast.Base, ctx: Context):
counter = ParamCounter()
counter.node_visit(query)
for _ in range(counter.param_count):
ctx.query_params.append(dbstate.SQLParamExternal())


class ParamCounter(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.param_count = 0

def visit_ParamRef(self, p: pgast.ParamRef) -> None:
if self.param_count < p.number:
self.param_count = p.number
17 changes: 2 additions & 15 deletions edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from edb.pgsql import common
from edb.pgsql import compiler as pgcompiler
from edb.pgsql.compiler import enums as pgce
from edb.server.compiler import dbstate

from edb.schema import types as s_types

Expand Down Expand Up @@ -515,20 +514,8 @@ def resolve_ParamRef(
*,
ctx: Context,
) -> pgast.ParamRef:
internal_index: Optional[int] = None
param: Optional[dbstate.SQLParam] = None
for i, p in enumerate(ctx.query_params):
if isinstance(p, dbstate.SQLParamExternal) and p.index == expr.number:
param = p
internal_index = i + 1
break
if not param:
param = dbstate.SQLParamExternal(index=expr.number)
internal_index = len(ctx.query_params) + 1
ctx.query_params.append(param)
assert internal_index

return pgast.ParamRef(number=internal_index)
# external params map one-to-one to internal params
return expr


@dispatch._resolve.register
Expand Down
4 changes: 2 additions & 2 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ class SQLParamExternal(SQLParam):
# An internal query param whose value is provided by an external param.
# So a user-visible param.

# External index
index: int
# External params share the index with internal params
pass


@dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False)
Expand Down
20 changes: 8 additions & 12 deletions edb/server/pgcon/pgcon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1782,23 +1782,19 @@ cdef class PGConnection:
# include the internal params for globals.
# This chunk of code remaps the descriptions of internal
# params into external ones.
count_internal = self.buffer.read_int16()
self.buffer.read_int16() # count_internal
data_internal = self.buffer.consume_message()

msg_buf = WriteBuffer.new_message(b't')
external_params = []
external_params: int64_t = 0
if action.query_unit.params:
for i_int, param in enumerate(action.query_unit.params):
if isinstance(param, dbstate.SQLParamExternal):
i_ext = param.index - 1
external_params.append((i_ext, i_int))
for index, param in enumerate(action.query_unit.params):
if not isinstance(param, dbstate.SQLParamExternal):
break
external_params = index + 1

msg_buf.write_int16(len(external_params))

external_params.sort()
for _, i_int in external_params:
oid = data_internal[i_int * 4:(i_int + 1) * 4]
msg_buf.write_bytes(oid)
msg_buf.write_int16(external_params)
msg_buf.write_bytes(data_internal[0:external_params * 4])

buf.write_buffer(msg_buf.end_message())

Expand Down
48 changes: 26 additions & 22 deletions edb/server/protocol/pg_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,8 @@ cdef WriteBuffer remap_arguments(
cdef:
int16_t param_format_count
int32_t offset
int16_t max_external_used
int32_t arg_offset_external
int16_t param_count_external
int32_t size

# The "external" parameters (that are visible to the user)
Expand Down Expand Up @@ -1621,30 +1622,40 @@ cdef WriteBuffer remap_arguments(
offset += param_format_count * 2

# find positions of external args
param_count_external = read_int16(data[offset:offset+2])
arg_count_external = read_int16(data[offset:offset+2])
offset += 2
param_pos_external = []
for p in range(param_count_external):
arg_offset_external = offset
for p in range(arg_count_external):
size = read_int32(data[offset:offset+4])
if size == -1: # special case: NULL
size = 0
size += 4 # for size which is int32
param_pos_external.append((offset, size))
offset += size

# write remapped args
max_external_used = 0
if params:
buf.write_int16(len(params))
for param in params:
if isinstance(param, dbstate.SQLParamExternal):
# map external arg to internal
o, s = param_pos_external[param.index - 1]
buf.write_bytes(data[o:o+s])

if max_external_used < param.index:
max_external_used = param.index
elif isinstance(param, dbstate.SQLParamGlobal):
param_count_external = 0
for i, param in enumerate(params):
if not isinstance(param, dbstate.SQLParamExternal):
break
param_count_external = i + 1
if param_count_external != arg_count_external:
raise pgerror.new(
pgerror.ERROR_PROTOCOL_VIOLATION,
f'bind message supplies {arg_count_external} '
f'parameters, but prepared statement "" requires '
f'{param_count_external}',
)

# write external args
if arg_offset_external < offset:
buf.write_bytes(data[arg_offset_external:offset])

# write global's args
for param in params[param_count_external:]:
if isinstance(param, dbstate.SQLParamGlobal):
name = param.global_name
setting_name = f'global {name.module}::{name.name}'
values = fe_settings.get(setting_name, None)
Expand All @@ -1656,14 +1667,7 @@ cdef WriteBuffer remap_arguments(
else:
buf.write_int16(0)

if max_external_used != param_count_external:
raise pgerror.new(
pgerror.ERROR_PROTOCOL_VIOLATION,
f'bind message supplies {param_count_external} '
f'parameters, but prepared statement "" requires '
f'{max_external_used}',
)

# result format codes
buf.write_bytes(data[offset:])
return buf

Expand Down

0 comments on commit 22cab5c

Please sign in to comment.