Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persist SQL param indexes over SQL adapter #7867

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions edb/pgsql/resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def resolve(

_ = context.ResolverContext(initial=ctx)

command.init_external_params(query, ctx)
top_level_ctes = command.compile_dml(query, ctx=ctx)

resolved = dispatch.resolve(query, ctx=ctx)
Expand Down
20 changes: 19 additions & 1 deletion edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,8 @@ def _uncompile_insert_pointer_stmt(
sub_name = sub.get_shortname(ctx.schema)

target_ql: qlast.Expr = qlast.Path(
steps=[value_ql, qlast.Ptr(name='__target__')])
steps=[value_ql, qlast.Ptr(name='__target__')]
)

if isinstance(sub_target, s_objtypes.ObjectType):
assert isinstance(target_ql, qlast.Path)
Expand Down Expand Up @@ -2093,3 +2094,20 @@ def __init__(self, mapping: Dict[int, int]) -> None:

def visit_Param(self, p: pgast.Param) -> None:
p.index = self.mapping[p.index]


def init_external_params(query: pgast.Base, ctx: Context):
counter = ParamCounter()
counter.node_visit(query)
for _ in range(counter.param_count):
ctx.query_params.append(dbstate.SQLParamExternal())


class ParamCounter(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.param_count = 0

def visit_ParamRef(self, p: pgast.ParamRef) -> None:
if self.param_count < p.number:
self.param_count = p.number
17 changes: 2 additions & 15 deletions edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from edb.pgsql import common
from edb.pgsql import compiler as pgcompiler
from edb.pgsql.compiler import enums as pgce
from edb.server.compiler import dbstate

from edb.schema import types as s_types

Expand Down Expand Up @@ -513,20 +512,8 @@ def resolve_ParamRef(
*,
ctx: Context,
) -> pgast.ParamRef:
internal_index: Optional[int] = None
param: Optional[dbstate.SQLParam] = None
for i, p in enumerate(ctx.query_params):
if isinstance(p, dbstate.SQLParamExternal) and p.index == expr.number:
param = p
internal_index = i + 1
break
if not param:
param = dbstate.SQLParamExternal(index=expr.number)
internal_index = len(ctx.query_params) + 1
ctx.query_params.append(param)
assert internal_index

return pgast.ParamRef(number=internal_index)
# external params map one-to-one to internal params
return expr


@dispatch._resolve.register
Expand Down
4 changes: 2 additions & 2 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ class SQLParamExternal(SQLParam):
# An internal query param whose value is provided by an external param.
# So a user-visible param.

# External index
index: int
# External params share the index with internal params
pass


@dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False)
Expand Down
20 changes: 8 additions & 12 deletions edb/server/pgcon/pgcon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1782,23 +1782,19 @@ cdef class PGConnection:
# include the internal params for globals.
# This chunk of code remaps the descriptions of internal
# params into external ones.
count_internal = self.buffer.read_int16()
self.buffer.read_int16() # count_internal
data_internal = self.buffer.consume_message()

msg_buf = WriteBuffer.new_message(b't')
external_params = []
external_params: int64_t = 0
if action.query_unit.params:
for i_int, param in enumerate(action.query_unit.params):
if isinstance(param, dbstate.SQLParamExternal):
i_ext = param.index - 1
external_params.append((i_ext, i_int))
for index, param in enumerate(action.query_unit.params):
if not isinstance(param, dbstate.SQLParamExternal):
break
external_params = index + 1

msg_buf.write_int16(len(external_params))

external_params.sort()
for _, i_int in external_params:
oid = data_internal[i_int * 4:(i_int + 1) * 4]
msg_buf.write_bytes(oid)
msg_buf.write_int16(external_params)
msg_buf.write_bytes(data_internal[0:external_params * 4])

buf.write_buffer(msg_buf.end_message())

Expand Down
48 changes: 26 additions & 22 deletions edb/server/protocol/pg_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,8 @@ cdef WriteBuffer remap_arguments(
cdef:
int16_t param_format_count
int32_t offset
int16_t max_external_used
int32_t arg_offset_external
int16_t param_count_external
int32_t size

# The "external" parameters (that are visible to the user)
Expand Down Expand Up @@ -1621,30 +1622,40 @@ cdef WriteBuffer remap_arguments(
offset += param_format_count * 2

# find positions of external args
param_count_external = read_int16(data[offset:offset+2])
arg_count_external = read_int16(data[offset:offset+2])
offset += 2
param_pos_external = []
for p in range(param_count_external):
arg_offset_external = offset
for p in range(arg_count_external):
size = read_int32(data[offset:offset+4])
if size == -1: # special case: NULL
size = 0
size += 4 # for size which is int32
param_pos_external.append((offset, size))
offset += size

# write remapped args
max_external_used = 0
if params:
buf.write_int16(len(params))
for param in params:
if isinstance(param, dbstate.SQLParamExternal):
# map external arg to internal
o, s = param_pos_external[param.index - 1]
buf.write_bytes(data[o:o+s])

if max_external_used < param.index:
max_external_used = param.index
elif isinstance(param, dbstate.SQLParamGlobal):
param_count_external = 0
for i, param in enumerate(params):
if not isinstance(param, dbstate.SQLParamExternal):
break
param_count_external = i + 1
if param_count_external != arg_count_external:
raise pgerror.new(
pgerror.ERROR_PROTOCOL_VIOLATION,
f'bind message supplies {arg_count_external} '
f'parameters, but prepared statement "" requires '
f'{param_count_external}',
)

# write external args
if arg_offset_external < offset:
buf.write_bytes(data[arg_offset_external:offset])

# write global's args
for param in params[param_count_external:]:
if isinstance(param, dbstate.SQLParamGlobal):
name = param.global_name
setting_name = f'global {name.module}::{name.name}'
values = fe_settings.get(setting_name, None)
Expand All @@ -1656,14 +1667,7 @@ cdef WriteBuffer remap_arguments(
else:
buf.write_int16(0)

if max_external_used != param_count_external:
raise pgerror.new(
pgerror.ERROR_PROTOCOL_VIOLATION,
f'bind message supplies {param_count_external} '
f'parameters, but prepared statement "" requires '
f'{max_external_used}',
)

# result format codes
buf.write_bytes(data[offset:])
return buf

Expand Down