Propagate spans in parsed SQL (#7466)
aljazerzen authored Jun 18, 2024
1 parent 772e120 commit 5eacce4
Showing 8 changed files with 68 additions and 36 deletions.
52 changes: 30 additions & 22 deletions edb/common/span.py
@@ -33,7 +33,7 @@
 
 from __future__ import annotations
 
-from typing import List
+from typing import Iterable, List
 import re
 import bisect
 
@@ -176,16 +176,20 @@ def get_span(*kids: List[ast.AST]):
     )
 
 
-def merge_spans(spans: List[Span]) -> Span:
-    spans.sort(key=lambda x: (x.start, x.end))
+def merge_spans(spans: Iterable[Span]) -> Span | None:
+    span_list = list(spans)
+    if not span_list:
+        return None
+
+    span_list.sort(key=lambda x: (x.start, x.end))
 
     # assume same name and buffer apply to all
     #
     return Span(
-        name=spans[0].name,
-        buffer=spans[0].buffer,
-        start=spans[0].start,
-        end=spans[-1].end,
+        name=span_list[0].name,
+        buffer=span_list[0].buffer,
+        start=span_list[0].start,
+        end=span_list[-1].end,
     )
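
A minimal sketch of the new merge_spans contract (assuming a dev checkout where edb.common.span exports both Span and merge_spans, constructed with the same keyword arguments the function itself uses): the input may now be any iterable, including a generator, and an empty input returns None instead of raising.

    from edb.common.span import Span, merge_spans

    buf = 'SELECT 1 + 2'
    a = Span(name='<query>', buffer=buf, start=7, end=8)
    b = Span(name='<query>', buffer=buf, start=11, end=12)

    merged = merge_spans(s for s in (a, b))   # generators are accepted now
    assert merged is not None
    assert (merged.start, merged.end) == (7, 12)
    assert merge_spans([]) is None            # empty input yields None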


@@ -233,43 +237,47 @@ class SpanPropagator(ast.NodeVisitor):
     also have correct span. For a node that has no span, its
     span is derived as a superset of all of the spans of its
     descendants.
+    If full_pass is True, nodes with span will still recurse into
+    children and their new span will also be superset of the existing span.
     """
 
-    def __init__(self, default=None):
+    def __init__(self, default=None, full_pass=False):
         super().__init__()
         self._default = default
+        self._full_pass = full_pass
 
-    def container_visit(self, node):
+    def repeated_node_visit(self, node):
+        return self.memo[node]
+
+    def container_visit(self, node) -> List[Span | None]:
         span_list = []
         for el in node:
             if isinstance(el, ast.AST) or typeutils.is_container(el):
                 span = self.visit(el)
 
-                if isinstance(span, list):
+                if not span:
+                    pass
+                elif isinstance(span, list):
                     span_list.extend(span)
                 else:
                     span_list.append(span)
         return span_list
 
     def generic_visit(self, node):
-        if getattr(node, 'span', None) is not None:
+        # base case: we already have span
+        if not self._full_pass and getattr(node, 'span', None) is not None:
             return node.span
 
-        # we need to derive span based on the children
+        # recurse into children fields
        span_list = self.container_visit(v for _, v in ast.iter_fields(node))
 
-        if None in span_list:
-            node.dump()
-            print(list(ast.iter_fields(node)))
+        # also include own span (this can only happen in full_pass)
+        if existing := getattr(node, 'span', None):
+            span_list.append(existing)
 
-        # now that we have all of the children spans, let's merge
-        # them into one
-        #
-        if span_list:
-            node.span = merge_spans(span_list)
-        else:
-            node.span = self._default
+        # merge spans into one
+        node.span = merge_spans(s for s in span_list if s) or self._default
 
         return node.span
 
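For intuition, a self-contained toy model of what full_pass changes (plain dataclasses, not the real edb.common.ast machinery): without it, a node that already has a span is returned as-is and its children are never filled in; with it, the visitor still recurses, so children derive spans from their own descendants and a node's existing span is merged into the superset.

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class Span:
        start: int
        end: int

    @dataclass
    class Node:
        span: Optional[Span] = None
        children: List["Node"] = field(default_factory=list)

    def propagate(node: Node, full_pass: bool = False) -> Optional[Span]:
        # base case: span already known and we are not forcing a full pass
        if not full_pass and node.span is not None:
            return node.span
        spans = [s for c in node.children if (s := propagate(c, full_pass))]
        if node.span:
            spans.append(node.span)  # full_pass: include the node's own span
        if spans:
            node.span = Span(min(s.start for s in spans),
                             max(s.end for s in spans))
        return node.span

    # A CTE-like parent that has a span while its nested statement does not.
    inner = Node(children=[Node(span=Span(12, 18)), Node(span=Span(20, 30))])
    outer = Node(span=Span(0, 31), children=[inner])

    propagate(outer)                  # default pass stops at outer
    assert inner.span is None
    propagate(outer, full_pass=True)  # full pass fills in inner as (12, 30)
    assert (inner.span.start, inner.span.end) == (12, 30)

This mirrors why ast_builder.py below runs the propagator with full_pass=True: a CommonTableExpr already carries a span from the parser, while the Insert nested inside it does not.
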
4 changes: 2 additions & 2 deletions edb/pgsql/ast.py
@@ -58,7 +58,7 @@ def dump_sql(self) -> None:
 
 
 class ImmutableBase(ast.ImmutableASTMixin, Base):
-    pass
+    __ast_mutable_fields__ = frozenset(['span'])
 
 
 class Alias(ImmutableBase):
@@ -212,7 +212,7 @@ class BaseRangeVar(ImmutableBaseExpr):
     """
 
     __ast_meta__ = {'schema_object_id', 'tag', 'ir_origins'}
-    __ast_mutable_fields__ = frozenset(['ir_origins'])
+    __ast_mutable_fields__ = frozenset(['ir_origins', 'span'])
 
     # This is a hack, since there is some code that relies on not
     # having an alias on a range var (to refer to a CTE directly, for
6 changes: 4 additions & 2 deletions edb/pgsql/parser/__init__.py
@@ -26,7 +26,9 @@
 from .ast_builder import build_stmts
 
 
-def parse(sql_query: str) -> List[pgast.Query | pgast.Statement]:
+def parse(
+    sql_query: str, propagate_spans: bool = False
+) -> List[pgast.Query | pgast.Statement]:
     ast_json = pg_parse(bytes(sql_query, encoding="UTF8"))
 
-    return build_stmts(json.loads(ast_json), sql_query)
+    return build_stmts(json.loads(ast_json), sql_query, propagate_spans)
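
A rough usage sketch of the new flag (assuming an EdgeDB dev environment where edb.pgsql.parser is importable), mirroring the call sites updated in compiler.py and the tests below:

    from edb.pgsql import parser as pg_parser

    stmts = pg_parser.parse('SELECT a, b FROM t WHERE a > 1', propagate_spans=True)
    for stmt in stmts:
        # With propagate_spans=True every node in the parsed tree should carry
        # a span; start and end are offsets into the original SQL text.
        assert stmt.span is not None
        print(type(stmt).__name__, stmt.span.start, stmt.span.end)
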
12 changes: 10 additions & 2 deletions edb/pgsql/parser/ast_builder.py
@@ -31,6 +31,7 @@
     cast,
 )
 
+from edb.common import span
 from edb.common.parsing import Span
 
 from edb.pgsql import ast as pgast
@@ -52,12 +53,12 @@ class Context:
 
 
 def build_stmts(
-    node: Node, source_sql: str
+    node: Node, source_sql: str, propagate_spans: bool
 ) -> List[pgast.Query | pgast.Statement]:
     ctx = Context(source_sql=source_sql)
 
     try:
-        return [_build_stmt(node["stmt"], ctx) for node in node["stmts"]]
+        res = [_build_stmt(node["stmt"], ctx) for node in node["stmts"]]
     except IndexError:
         raise PSqlUnsupportedError()
     except PSqlUnsupportedError as e:
@@ -68,6 +69,13 @@ def build_stmts(
             e.message += source_sql[e.location : (e.location + 50)]
         raise
 
+    if propagate_spans:
+        # we need to do a full pass of span propagation, because some
+        # nodes (CommonTableExpr) have span, but their children don't (Insert).
+        span.SpanPropagator(full_pass=True).container_visit(res)
+
+    return res
+
 
 def _maybe(
     node: Node, ctx: Context, name: str, builder: Builder
3 changes: 2 additions & 1 deletion edb/pgsql/resolver/relation.py
@@ -193,9 +193,10 @@ def register_projections(target_list: List[pgast.ResTarget], *, ctx: Context):
 def resolve_DMLQuery(
     query: pgast.DMLQuery, *, ctx: Context
 ) -> Tuple[pgast.DMLQuery, context.Table]:
-    raise errors.UnsupportedFeatureError(
+    raise errors.QueryError(
         'DML queries (INSERT/UPDATE/DELETE) are not supported',
         span=query.span,
+        pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED,
     )
 
 
2 changes: 1 addition & 1 deletion edb/server/compiler/compiler.py
@@ -615,7 +615,7 @@ def compute_stmt_name(text: str) -> str:
         'server_version': False,
         'server_version_num': False,
     }
-    stmts = pg_parser.parse(query_str)
+    stmts = pg_parser.parse(query_str, propagate_spans=True)
     sql_units = []
     for stmt in stmts:
         orig_text = pg_gen_source(stmt)
2 changes: 1 addition & 1 deletion tests/test_sql_parse.py
@@ -41,7 +41,7 @@ def normalize(s):
         else:
             expected = source
 
-        ast = parser.parse(source)
+        ast = parser.parse(source, propagate_spans=True)
         sql_stmts = [
             codegen.generate_source(stmt, pretty=False) for stmt in ast
         ]
23 changes: 18 additions & 5 deletions tests/test_sql_query.py
@@ -383,13 +383,13 @@ async def test_sql_query_25(self):
         await self.scon.fetch('SELECT title FROM "novel" ORDER BY title')
 
         with self.assertRaisesRegex(
-            asyncpg.UndefinedTableError, "unknown table"
+            asyncpg.UndefinedTableError, "unknown table", position="19",
         ):
             await self.scon.fetch('SELECT title FROM "Novel" ORDER BY title')
 
     async def test_sql_query_26(self):
         with self.assertRaisesRegex(
-            asyncpg.UndefinedTableError, "unknown table"
+            asyncpg.UndefinedTableError, "unknown table", position="19",
         ):
             await self.scon.fetch('SELECT title FROM Movie ORDER BY title')
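
The position arguments are the new part of these assertions. Assuming asyncpg surfaces PostgreSQL's convention of a 1-based character offset into the query text, the asserted value lines up with where the unknown table name begins:

    # Both failing queries place the table name at 1-based offset 19.
    assert 'SELECT title FROM "Novel" ORDER BY title'.index('"Novel"') + 1 == 19
    assert 'SELECT title FROM Movie ORDER BY title'.index('Movie') + 1 == 19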

@@ -446,7 +446,10 @@ async def test_sql_query_30(self):
         self.assert_shape(res, 2, ['c', 'd'])
 
         with self.assertRaisesRegex(
-            asyncpg.InvalidColumnReferenceError, "query resolves to 2"
+            asyncpg.InvalidColumnReferenceError,
+            ", but the query resolves to 2 columns",
+            # this points to `1`, because libpg_query does not give better info
+            position="41",
         ):
             await self.scon.fetch(
                 '''
@@ -838,13 +841,13 @@ async def test_sql_query_schemas(self):
 
         await self.scon.execute('SET search_path TO public;')
         with self.assertRaisesRegex(
-            asyncpg.UndefinedTableError, "unknown table"
+            asyncpg.UndefinedTableError, "unknown table", position="16",
         ):
             await self.squery_values('SELECT id FROM "Item"')
 
         await self.scon.execute('SET search_path TO inventory;')
         with self.assertRaisesRegex(
-            asyncpg.UndefinedTableError, "unknown table"
+            asyncpg.UndefinedTableError, "unknown table", position="17",
         ):
             await self.scon.fetch('SELECT id FROM "Person";')
 
@@ -1354,3 +1357,13 @@ async def test_sql_query_computed_09(self):
             SELECT similar_to FROM "Movie"
             """
         )
+
+    async def test_sql_dml_insert(self):
+        with self.assertRaisesRegex(
+            asyncpg.FeatureNotSupportedError, "DML", position="30",
+        ):
+            await self.scon.fetch(
+                """
+                INSERT INTO "Movie" (title) VALUES ('A man called Ove')
+                """
+            )
