Implement position mapping for SQL over native proto (#8116)
This is a little fiddly, since we need to account for query
normalization. We do this by traversing the pre- and post-normalization
ASTs in parallel, building up a map. In my testing, postgres only
returned error positions for errors that could be detected while
describing the query, so to simplify the implementation we only remap
errors that happen there.

Also do a better job of translating the errors:
 * Don't raise an ISE when the backend error doesn't map to an EdgeDB error
 * Preserve the backend's hint and details when we can

Fixes #8077.
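
As a rough illustration of the scheme (not part of the commit; offsets
are made up): normalization replaces literals with parameters, so a
position that postgres reports against the normalized text has to be
walked back through a map keyed by spans of the original text.

    # Toy sketch of the remapping idea; hypothetical offsets, not the
    # commit's API (the real map is built from parser spans below).
    orig = "select 1 + 'foo'"   # 'foo' starts at offset 11
    norm = "select $1 + $2"     # $2 starts at offset 12

    # map from offsets in the normalized text to offsets in the original
    span_map = {0: 0, 7: 7, 12: 11}

    def remap(pos: int) -> int:
        # use the closest mapped span starting at or before pos
        return span_map[max(s for s in span_map if s <= pos)]

    assert remap(12) == 11   # an error on $2 points back at 'foo'
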
msullivan authored Dec 13, 2024
1 parent 4636d06 commit 47efb6e
Showing 7 changed files with 175 additions and 47 deletions.
41 changes: 25 additions & 16 deletions edb/pgsql/codegen.py
@@ -21,6 +21,7 @@

from typing import Any, Optional, Sequence, List

import abc
import collections
import dataclasses

@@ -113,33 +114,41 @@ def generate_ctes_source(


class TranslationData:
    @abc.abstractmethod
    def translate(self, pos: int) -> int:
        ...


@dataclasses.dataclass(kw_only=True)
class BaseTranslationData(TranslationData):
    source_start: int
    output_start: int
    output_end: int
    children: List[TranslationData]

    def __init__(
        self,
        *,
        source_start: int,
        output_start: int,
    ):
        self.source_start = source_start
        self.output_start = output_start
        self.output_end = -1
        self.children = []
    output_end: int | None = None
    children: List[BaseTranslationData] = (
        dataclasses.field(default_factory=list))

    def translate(self, pos: int) -> int:
        bu = None
        for u in self.children:
            if u.output_start >= pos:
                break
            bu = u
        if bu and bu.output_end > pos:
        if bu and (bu.output_end is None or bu.output_end > pos):
            return bu.translate(pos)
        return self.source_start


@dataclasses.dataclass
class ChainedTranslationData(TranslationData):
    parts: List[TranslationData] = (
        dataclasses.field(default_factory=list))

    def translate(self, pos: int) -> int:
        for part in self.parts:
            pos = part.translate(pos)
        return pos
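
A minimal usage sketch of the two classes above (offsets are made up):
each node records where its source began and the half-open output range
it produced, and translate() descends into the last child whose output
range still covers the position.

    # Usage sketch, assuming the classes above are in scope;
    # all offsets are hypothetical.
    root = BaseTranslationData(source_start=0, output_start=0, output_end=40)
    root.children.append(
        BaseTranslationData(source_start=7, output_start=10, output_end=25))

    assert root.translate(12) == 7   # inside the child's output range
    assert root.translate(30) == 0   # past the child; fall back to the root

    # ChainedTranslationData composes maps in order, e.g. resolver
    # output -> normalized query, then normalized -> original query.
    chain = ChainedTranslationData([root])
    assert chain.translate(12) == 7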


@dataclasses.dataclass(frozen=True)
class SQLSource:
    text: str
@@ -168,7 +177,7 @@ def __init__(
        self.param_index: collections.defaultdict[int, list[int]] = (
            collections.defaultdict(list))
        self.write_index: int = 0
        self.translation_data: Optional[TranslationData] = None
        self.translation_data: Optional[BaseTranslationData] = None

    def write(
        self,
@@ -182,7 +191,7 @@ def write(

    def visit(self, node):  # type: ignore
        if self.with_translation_data:
            translation_data = TranslationData(
            translation_data = BaseTranslationData(
                source_start=node.span.start if node.span else 0,
                output_start=self.write_index,
            )
3 changes: 3 additions & 0 deletions edb/server/compiler/compiler.py
@@ -2672,6 +2672,7 @@ def compile_sql_as_unit_group(
f"unexpected SQLQueryUnit.command_complete_tag type: "
f"{sql_unit.command_complete_tag}"
)

        unit = dbstate.QueryUnit(
            sql=value_sql,
            introspection_sql=intro_sql,
@@ -2687,6 +2688,8 @@
                if sql_unit.cardinality is enums.Cardinality.NO_RESULT
                else enums.OutputFormat.BINARY
            ),
            translation_data=sql_unit.translation_data,
            sql_prefix_len=sql_unit.prefix_len,
        )
        match sql_unit.tx_action:
            case dbstate.TxAction.START:
9 changes: 6 additions & 3 deletions edb/server/compiler/dbstate.py
@@ -345,6 +345,12 @@ class QueryUnit:
    run_and_rollback: bool = False
    append_tx_op: bool = False

    # Translation source map.
    translation_data: Optional[pgcodegen.TranslationData] = None
    # For SQL queries, the length of the query prefix applied
    # after translation.
    sql_prefix_len: int = 0

    @property
    def has_ddl(self) -> bool:
        return bool(self.capabilities & enums.Capability.DDL)
@@ -524,9 +530,6 @@ class SQLQueryUnit:
        repr=False, default=None)
    """Translated query text returning data in single-column format."""

    eql_format_translation_data: Optional[pgcodegen.TranslationData] = None
    """Translation source map for single-column format query."""

    orig_query: str = dataclasses.field(repr=False)
    """Original query text before translation."""

76 changes: 72 additions & 4 deletions edb/server/compiler/sql.py
@@ -27,6 +27,7 @@
import json

from edb import errors
from edb.common import ast
from edb.common import uuidgen
from edb.server import defines

@@ -76,6 +77,7 @@ def compile_sql(
    def _try(q: str) -> List[dbstate.SQLQueryUnit]:
        return _compile_sql(
            q,
            orig_query_str=source.original_text(),
            schema=schema,
            tx_state=tx_state,
            prepared_stmt_map=prepared_stmt_map,
@@ -109,9 +111,58 @@ def _try(q: str) -> List[dbstate.SQLQueryUnit]:
        raise original_err


def _build_constant_extraction_map(
    src: pgast.Base,
    out: pgast.Base,
) -> pg_codegen.BaseTranslationData:
    """Traverse two ASTs in parallel and build a source map between them.

    The ASTs should *mostly* line up. When they don't, that is
    considered a leaf.

    This is used to translate SQL spans reported on a normalized query
    to ones that make sense on the pre-normalization version.

    Note that we only use this map for errors reported during the
    "parse" phase, so we don't need to worry about it being reused
    with different constants.
    """
    tdata = pg_codegen.BaseTranslationData(
        source_start=src.span.start if src.span else 0,
        # HACK: I don't know why, but this - 1 helps a lot.
        output_start=out.span.start - 1 if out.span else 0,
    )
    if type(src) != type(out):
        return tdata

    children = tdata.children
    for (k1, v1), (k2, v2) in zip(ast.iter_fields(src), ast.iter_fields(out)):
        assert k1 == k2

        if isinstance(v1, pgast.Base) and isinstance(v2, pgast.Base):
            children.append(_build_constant_extraction_map(v1, v2))
        elif (
            isinstance(v1, (tuple, list)) and isinstance(v2, (tuple, list))
        ):
            for v1e, v2e in zip(v1, v2):
                if isinstance(v1e, pgast.Base) and isinstance(v2e, pgast.Base):
                    children.append(_build_constant_extraction_map(v1e, v2e))
        elif (
            isinstance(v1, dict) and isinstance(v2, dict)
        ):
            for k, v1e in v1.items():
                v2e = v2.get(k)
                if isinstance(v1e, pgast.Base) and isinstance(v2e, pgast.Base):
                    children.append(_build_constant_extraction_map(v1e, v2e))

    children.sort(key=lambda k: k.output_start)

    return tdata
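
A hypothetical usage sketch of this helper, reusing the module's
pg_parser import from above; the position passed to translate() is
illustrative, since the real values come from the parser's spans.

    # Hypothetical sketch: map a position on the normalized query back
    # to the original.
    orig = "SELECT 1 + 'foo'"   # what the user actually sent
    norm = "SELECT $1 + $2"     # after constant extraction
    [orig_stmt] = pg_parser.parse(orig, propagate_spans=True)
    [norm_stmt] = pg_parser.parse(norm, propagate_spans=True)

    tdata = _build_constant_extraction_map(orig_stmt, norm_stmt)
    orig_pos = tdata.translate(12)   # position reported against `norm`
                                     # -> start of the enclosing node in `orig`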


def _compile_sql(
    query_str: str,
    *,
    orig_query_str: Optional[str] = None,
    schema: s_schema.Schema,
    tx_state: dbstate.SQLTransactionState,
    prepared_stmt_map: Mapping[str, str],
@@ -137,13 +188,21 @@ def _compile_sql(
        disambiguate_column_names=disambiguate_column_names,
    )

    # orig_stmts are the statements prior to constant extraction
    stmts = pg_parser.parse(query_str, propagate_spans=True)
    if orig_query_str and orig_query_str != query_str:
        orig_stmts = pg_parser.parse(orig_query_str, propagate_spans=True)
    else:
        orig_stmts = stmts

    sql_units = []
    for stmt in stmts:
    for stmt, orig_stmt in zip(stmts, orig_stmts):
        orig_text = pg_codegen.generate_source(stmt)
        fe_settings = tx_state.current_fe_settings()
        track_stats = False

        extract_data = _build_constant_extraction_map(orig_stmt, stmt)

        unit = dbstate.SQLQueryUnit(
            orig_query=orig_text,
            fe_settings=fe_settings,
@@ -351,11 +410,20 @@
        )
        unit.query = stmt_source.text
        unit.translation_data = stmt_source.translation_data
        if stmt_source.translation_data:
            unit.translation_data = (
                pg_codegen.ChainedTranslationData([
                    stmt_source.translation_data,
                    extract_data,
                ])
            )

        if edgeql_fmt_src is not None:
            unit.eql_format_query = edgeql_fmt_src.text
            unit.eql_format_translation_data = (
                edgeql_fmt_src.translation_data
            )
            # We don't do anything with the translation data for
            # this query, since postgres typically doesn't report
            # out error positions that didn't get reported during
            # the "parse" phase.
        unit.command_complete_tag = stmt_resolved.command_complete_tag
        unit.params = stmt_resolved.params
        if isinstance(stmt, pgast.DMLQuery) and not stmt.returning_list:
19 changes: 16 additions & 3 deletions edb/server/dbview/dbview.pyx
@@ -46,6 +46,7 @@ from edb.server.compiler import dbstate, enums, sertypes
from edb.server.protocol import execute
from edb.pgsql import dbops
from edb.server.compiler_pool import state as compiler_state_mod
from edb.server.pgcon import errors as pgerror

from edb.server.protocol import ai_ext

@@ -1457,9 +1458,21 @@

        intro_sql = query_unit.introspection_sql
        if intro_sql is None:
            intro_sql = query_unit.sql[0]
        param_desc, result_desc = await pgcon.sql_describe(
            intro_sql, all_type_oids)
            intro_sql = query_unit.sql
        try:
            param_desc, result_desc = await pgcon.sql_describe(
                intro_sql, all_type_oids)
        except pgerror.BackendError as ex:
            ex._from_sql = True
            if 'P' in ex.fields:
                ex.fields['P'] = str(
                    int(ex.fields['P']) - query_unit.sql_prefix_len
                )
            if query_unit.translation_data:
                ex._translation_data = query_unit.translation_data

            raise

        result_types = []
        for col, toid in result_desc:
            edb_type_id = self._db.backend_oid_to_id.get(toid)
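
Spelling out the 'P' fix-up above with made-up numbers: postgres reports
the position against the SQL it was actually sent, which includes a
prefix we prepended, so the prefix length is stripped before the
translation map (applied later, in interpret_error) sees the position.

    # Sketch of the position arithmetic above; the numbers are made up.
    fields = {'P': '42'}    # 1-based error position from the backend
    sql_prefix_len = 17     # length of the prefix we prepended

    # strip the prefix so 'P' indexes the translated query text
    fields['P'] = str(int(fields['P']) - sql_prefix_len)
    assert fields['P'] == '25'
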
41 changes: 39 additions & 2 deletions edb/server/protocol/execute.pyx
@@ -43,6 +43,7 @@ from edb.server import compiler
from edb.server import config
from edb.server import defines as edbdef
from edb.server import metrics
from edb.server.compiler import dbstate
from edb.server.compiler import errormech
from edb.server.compiler cimport rpc
from edb.server.compiler import sertypes
@@ -379,6 +380,8 @@ async def execute(
        if query_unit.user_schema:
            if isinstance(ex, pgerror.BackendError):
                ex._user_schema = query_unit.user_schema
        if query_unit.translation_data:
            ex._from_sql = True

        dbv.on_error()

@@ -439,6 +442,7 @@ async def execute_script(
    global_schema = roles = None
    unit_group = compiled.query_unit_group
    query_prefix = compiled.make_query_prefix()
    query_unit = None

    sync = False
    no_sync = False
@@ -570,6 +574,9 @@
        if isinstance(e, pgerror.BackendError):
            e._user_schema = dbv.get_user_schema_pickle()

        if query_unit and query_unit.translation_data:
            e._from_sql = True

        if not in_tx and dbv.in_tx():
            # Abort the implicit transaction
            dbv.abort_tx()
@@ -931,8 +938,12 @@ async def interpret_error(

    elif isinstance(exc, pgerror.BackendError):
        try:
            from_sql = getattr(exc, '_from_sql', False)
            translation_data = getattr(exc, '_translation_data', None)
            fields = exc.fields

            static_exc = errormech.static_interpret_backend_error(
                exc.fields, from_graphql=from_graphql
                fields, from_graphql=from_graphql
            )

            # only use the backend if schema is required
@@ -950,7 +961,7 @@
                exc = await compiler_pool.interpret_backend_error(
                    user_schema_pickle,
                    global_schema_pickle,
                    exc.fields,
                    fields,
                    from_graphql,
                )

@@ -963,6 +974,32 @@
            else:
                exc = static_exc

            if from_sql and isinstance(exc, errors.InternalServerError):
                exc = errors.ExecutionError(*exc.args)

            # Translate error position for SQL queries if we can
            if translation_data and isinstance(exc, errors.EdgeDBError):
                if 'P' in fields:
                    exc.set_position(
                        0,
                        0,
                        translation_data.translate(int(fields['P'])),
                        None,
                    )

            # Include hint/detail from SQL queries also, if we haven't
            # produced our own.
            if from_sql and isinstance(exc, errors.EdgeDBError):
                if 'H' in fields or 'D' in fields:
                    hint = exc.hint or fields.get('H')
                    details = exc.details or fields.get('D')
                    # ... there is some sort of cython bug/"feature"
                    # involving the type annotation above which causes
                    # exc.set_hint_and_details to fail, so we copy it
                    # to a new variable.
                    exc2: object = exc
                    exc2.set_hint_and_details(hint, details)

        except Exception as e:
            from edb.common import debug
            if debug.flags.server:
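
The hint/detail fallback above in isolation (values are made up):
postgres's single-letter error fields include 'P' (position), 'H'
(hint) and 'D' (detail), and the backend's values are only used where
our interpreted error didn't provide its own.

    # Sketch of the hint/detail fallback; values are made up.
    fields = {'H': 'Perhaps you need a cast?', 'D': 'operator is not unique'}

    own_hint = None                  # our interpreted error had no hint
    own_details = 'operator error'   # ...but it did carry details

    hint = own_hint or fields.get('H')        # backend hint fills the gap
    details = own_details or fields.get('D')  # our own details win
    assert hint == 'Perhaps you need a cast?'
    assert details == 'operator error'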
