From dfd3c0c5eb03a02f8abc8a248c580feb1e2fe3ce Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 6 Dec 2024 10:36:45 -0800 Subject: [PATCH] Try non-normalized source on a SQL error (#8079) Fixes part of #8077. We still need to handle backend spans. --- edb/pgsql/parser/parser.pyx | 6 ++++ edb/server/compiler/compiler.py | 7 +++-- edb/server/compiler/sql.py | 52 +++++++++++++++++++++++++++++++++ tests/test_sql_query.py | 39 +++++++++++++++++++++++-- 4 files changed, 99 insertions(+), 5 deletions(-) diff --git a/edb/pgsql/parser/parser.pyx b/edb/pgsql/parser/parser.pyx index fe0515d47df..b50eb6f3b1b 100644 --- a/edb/pgsql/parser/parser.pyx +++ b/edb/pgsql/parser/parser.pyx @@ -233,6 +233,9 @@ cdef class Source: def text(self) -> str: return self._text + def original_text(self) -> str: + return self._text + def cache_key(self) -> bytes: if not self._cache_key: h = hashlib.blake2b(self._tag().to_bytes()) @@ -281,6 +284,9 @@ cdef class NormalizedSource(Source): def _tag(cls) -> int: return 1 + def original_text(self) -> str: + return self._orig_text + cdef WriteBuffer _serialize(self): cdef WriteBuffer buf diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 448729ecd22..fcf235b36ac 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -91,6 +91,7 @@ from edb.pgsql import debug as pg_debug from edb.pgsql import dbops as pg_dbops from edb.pgsql import params as pg_params +from edb.pgsql import parser as pg_parser from edb.pgsql import patches as pg_patches from edb.pgsql import types as pg_types from edb.pgsql import delta as pg_delta @@ -557,8 +558,10 @@ def compile_sql( if setting and setting.value: apply_access_policies_pg = sql.is_setting_truthy(setting.value) + query_source = pg_parser.Source(query_str) + return sql.compile_sql( - query_str, + query_source, schema=schema, tx_state=tx_state, prepared_stmt_map=prepared_stmt_map, @@ -2509,7 +2512,7 @@ def compile_sql_as_unit_group( ) sql_units = sql.compile_sql( - source.text(), + source, schema=schema, tx_state=sql_tx_state, prepared_stmt_map={}, diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index bbff8d35dd5..50d444e42f4 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -58,6 +58,58 @@ def compile_sql( + source: pg_parser.Source, + *, + schema: s_schema.Schema, + tx_state: dbstate.SQLTransactionState, + prepared_stmt_map: Mapping[str, str], + current_database: str, + current_user: str, + allow_user_specified_id: Optional[bool], + apply_access_policies_sql: Optional[bool], + include_edgeql_io_format_alternative: bool = False, + allow_prepared_statements: bool = True, + disambiguate_column_names: bool, + backend_runtime_params: pg_params.BackendRuntimeParams, + protocol_version: defines.ProtocolVersion, +) -> List[dbstate.SQLQueryUnit]: + def _try(q: str) -> List[dbstate.SQLQueryUnit]: + return _compile_sql( + q, + schema=schema, + tx_state=tx_state, + prepared_stmt_map=prepared_stmt_map, + current_database=current_database, + current_user=current_user, + allow_user_specified_id=allow_user_specified_id, + apply_access_policies_sql=apply_access_policies_sql, + include_edgeql_io_format_alternative=( + include_edgeql_io_format_alternative), + allow_prepared_statements=allow_prepared_statements, + disambiguate_column_names=disambiguate_column_names, + backend_runtime_params=backend_runtime_params, + protocol_version=protocol_version, + ) + + try: + return _try(source.text()) + except errors.EdgeDBError as original_err: + if isinstance(source, pg_parser.NormalizedSource): + # try non-normalized source + try: + _try(source.original_text()) + except errors.EdgeDBError as denormalized_err: + raise denormalized_err + except Exception: + raise original_err + else: + raise AssertionError( + "Normalized query is broken while original is valid") + else: + raise original_err + + +def _compile_sql( query_str: str, *, schema: s_schema.Schema, diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index a4d8221ca1d..a05a3e1e173 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -2548,7 +2548,7 @@ async def test_native_sql_query_04(self): with self.assertRaisesRegex( edgedb.errors.QueryError, 'duplicate column name: `a`', - _position=16, + _position=15, ): await self.assert_sql_query_result('SELECT 1 AS a, 2 AS a', []) @@ -2581,7 +2581,7 @@ async def test_native_sql_query_07(self): with self.assertRaisesRegex( edgedb.errors.QueryError, 'duplicate column name: `y_a`', - # _position=114, TODO: spans are messed up somewhere + _position=137, ): await self.assert_sql_query_result( ''' @@ -2597,7 +2597,7 @@ async def test_native_sql_query_08(self): with self.assertRaisesRegex( edgedb.errors.QueryError, 'duplicate column name: `x_a`', - # _position=83, TODO: spans are messed up somewhere + _position=92, ): await self.assert_sql_query_result( ''' @@ -2859,3 +2859,36 @@ async def test_native_sql_query_17(self): }, apply_access_policies=False, ) + + async def test_native_sql_query_18(self): + with self.assertRaisesRegex( + edgedb.errors.QueryError, + 'cannot find column `asdf`', + _position=35, + ): + await self.con.query_sql( + '''select title, 'aaaaaaaaaaaaaaaaa', asdf from "Content";''' + ) + + @test.xerror('See #8077') + async def test_native_sql_query_19(self): + with self.assertRaisesRegex( + edgedb.errors.QueryError, + '', + _position=37, + ): + await self.con.query_sql( + '''select title, 'aaaaaaaaaaaaaaaaa', asdf() from "Content";''' + ) + + @test.xfail('See #8077') + async def test_native_sql_query_20(self): + with self.assertRaisesRegex( + edgedb.errors.InvalidValueError, + 'invalid input syntax for type integer', + _position=35, + ): + await self.con.query_sql( + '''\ +select title, 'aaaaaaaaaaaaaaaaa', ('goo'::text::integer) from "Content";''' + )