Skip to content

Commit

Permalink
Try non-normalized source on a SQL error (#8079)
Browse files Browse the repository at this point in the history
Fixes part of #8077. We still need to handle backend spans.
  • Loading branch information
msullivan authored Dec 6, 2024
1 parent 696bf13 commit dfd3c0c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
6 changes: 6 additions & 0 deletions edb/pgsql/parser/parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ cdef class Source:
def text(self) -> str:
return self._text

def original_text(self) -> str:
return self._text

def cache_key(self) -> bytes:
if not self._cache_key:
h = hashlib.blake2b(self._tag().to_bytes())
Expand Down Expand Up @@ -281,6 +284,9 @@ cdef class NormalizedSource(Source):
def _tag(cls) -> int:
return 1

def original_text(self) -> str:
return self._orig_text

cdef WriteBuffer _serialize(self):
cdef WriteBuffer buf

Expand Down
7 changes: 5 additions & 2 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from edb.pgsql import debug as pg_debug
from edb.pgsql import dbops as pg_dbops
from edb.pgsql import params as pg_params
from edb.pgsql import parser as pg_parser
from edb.pgsql import patches as pg_patches
from edb.pgsql import types as pg_types
from edb.pgsql import delta as pg_delta
Expand Down Expand Up @@ -557,8 +558,10 @@ def compile_sql(
if setting and setting.value:
apply_access_policies_pg = sql.is_setting_truthy(setting.value)

query_source = pg_parser.Source(query_str)

return sql.compile_sql(
query_str,
query_source,
schema=schema,
tx_state=tx_state,
prepared_stmt_map=prepared_stmt_map,
Expand Down Expand Up @@ -2509,7 +2512,7 @@ def compile_sql_as_unit_group(
)

sql_units = sql.compile_sql(
source.text(),
source,
schema=schema,
tx_state=sql_tx_state,
prepared_stmt_map={},
Expand Down
52 changes: 52 additions & 0 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,58 @@


def compile_sql(
source: pg_parser.Source,
*,
schema: s_schema.Schema,
tx_state: dbstate.SQLTransactionState,
prepared_stmt_map: Mapping[str, str],
current_database: str,
current_user: str,
allow_user_specified_id: Optional[bool],
apply_access_policies_sql: Optional[bool],
include_edgeql_io_format_alternative: bool = False,
allow_prepared_statements: bool = True,
disambiguate_column_names: bool,
backend_runtime_params: pg_params.BackendRuntimeParams,
protocol_version: defines.ProtocolVersion,
) -> List[dbstate.SQLQueryUnit]:
def _try(q: str) -> List[dbstate.SQLQueryUnit]:
return _compile_sql(
q,
schema=schema,
tx_state=tx_state,
prepared_stmt_map=prepared_stmt_map,
current_database=current_database,
current_user=current_user,
allow_user_specified_id=allow_user_specified_id,
apply_access_policies_sql=apply_access_policies_sql,
include_edgeql_io_format_alternative=(
include_edgeql_io_format_alternative),
allow_prepared_statements=allow_prepared_statements,
disambiguate_column_names=disambiguate_column_names,
backend_runtime_params=backend_runtime_params,
protocol_version=protocol_version,
)

try:
return _try(source.text())
except errors.EdgeDBError as original_err:
if isinstance(source, pg_parser.NormalizedSource):
# try non-normalized source
try:
_try(source.original_text())
except errors.EdgeDBError as denormalized_err:
raise denormalized_err
except Exception:
raise original_err
else:
raise AssertionError(
"Normalized query is broken while original is valid")
else:
raise original_err


def _compile_sql(
query_str: str,
*,
schema: s_schema.Schema,
Expand Down
39 changes: 36 additions & 3 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2548,7 +2548,7 @@ async def test_native_sql_query_04(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,
'duplicate column name: `a`',
_position=16,
_position=15,
):
await self.assert_sql_query_result('SELECT 1 AS a, 2 AS a', [])

Expand Down Expand Up @@ -2581,7 +2581,7 @@ async def test_native_sql_query_07(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,
'duplicate column name: `y_a`',
# _position=114, TODO: spans are messed up somewhere
_position=137,
):
await self.assert_sql_query_result(
'''
Expand All @@ -2597,7 +2597,7 @@ async def test_native_sql_query_08(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,
'duplicate column name: `x_a`',
# _position=83, TODO: spans are messed up somewhere
_position=92,
):
await self.assert_sql_query_result(
'''
Expand Down Expand Up @@ -2859,3 +2859,36 @@ async def test_native_sql_query_17(self):
},
apply_access_policies=False,
)

async def test_native_sql_query_18(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,
'cannot find column `asdf`',
_position=35,
):
await self.con.query_sql(
'''select title, 'aaaaaaaaaaaaaaaaa', asdf from "Content";'''
)

@test.xerror('See #8077')
async def test_native_sql_query_19(self):
with self.assertRaisesRegex(
edgedb.errors.QueryError,
'',
_position=37,
):
await self.con.query_sql(
'''select title, 'aaaaaaaaaaaaaaaaa', asdf() from "Content";'''
)

@test.xfail('See #8077')
async def test_native_sql_query_20(self):
with self.assertRaisesRegex(
edgedb.errors.InvalidValueError,
'invalid input syntax for type integer',
_position=35,
):
await self.con.query_sql(
'''\
select title, 'aaaaaaaaaaaaaaaaa', ('goo'::text::integer) from "Content";'''
)

0 comments on commit dfd3c0c

Please sign in to comment.