Skip to content

Commit

Permalink
Fix RECURSIVE over SQL adapter (#8132)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Dec 17, 2024
1 parent 298a43c commit 8d37e82
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 11 deletions.
2 changes: 1 addition & 1 deletion edb/pgsql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def gen_ctes(self, ctes: List[pgast.CommonTableExpr]) -> None:
count = len(ctes)
for i, cte in enumerate(ctes):
self.new_lines = 1
if getattr(cte, 'recursive', None):
if i == 0 and getattr(cte, 'recursive', None):
self.write('RECURSIVE ')
self.write(common.quote_ident(cte.name))

Expand Down
17 changes: 7 additions & 10 deletions edb/pgsql/parser/ast_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,12 @@ def _build_record_indirection_op(


def _build_ctes(n: Node, c: Context) -> List[pgast.CommonTableExpr]:
return _list(n, c, "ctes", _build_cte)
is_recursive = _maybe(n, c, 'recursive', lambda x, _: bool(x)) or False

ctes: List[pgast.CommonTableExpr] = _list(n, c, "ctes", _build_cte)
for cte in ctes:
cte.recursive = is_recursive
return ctes


def _build_cte(n: Node, c: Context) -> pgast.CommonTableExpr:
Expand All @@ -660,18 +665,10 @@ def _build_cte(n: Node, c: Context) -> pgast.CommonTableExpr:
elif n["ctematerialized"] == "CTEMaterializeNever":
materialized = False

recursive = _bool_or_false(n, "cterecursive")

# workaround because libpg_query does not actually emit cterecursive
if "cterecursive" not in n:
location = n["location"]
if 'RECURSIVE' in c.source_sql[:location][-15:].upper():
recursive = True

return pgast.CommonTableExpr(
name=n["ctename"],
query=_build_query(n["ctequery"], c),
recursive=recursive,
recursive=False,
aliascolnames=_maybe_list(n, c, "aliascolnames", _build_str),
materialized=materialized,
span=_build_span(n, c),
Expand Down
95 changes: 95 additions & 0 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,101 @@ async def test_sql_query_55(self):
)
]])

async def test_sql_query_56(self):
# recursive

res = await self.squery_values(
'''
WITH RECURSIVE
integers(n) AS (
SELECT 0
UNION ALL
SELECT n + 1 FROM integers
WHERE n + 1 < 5
)
SELECT n FROM integers
''',
)
self.assertEqual(res, [
[0],
[1],
[2],
[3],
[4],
])

res = await self.squery_values(
'''
WITH RECURSIVE
fibonacci(n, prev, val) AS (
SELECT 1, 0, 1
UNION ALL
SELECT n + 1, val, prev + val
FROM fibonacci
WHERE n + 1 < 10
)
SELECT n, val FROM fibonacci;
'''
)
self.assertEqual(res, [
[1, 1],
[2, 1],
[3, 2],
[4, 3],
[5, 5],
[6, 8],
[7, 13],
[8, 21],
[9, 34],
])

res = await self.squery_values(
'''
WITH RECURSIVE
fibonacci(n, prev, val) AS (
SELECT 1, 0, 1
UNION ALL
SELECT n + 1, val, prev + val
FROM fibonacci
WHERE n + 1 < 8
),
integers(n) AS (
SELECT 0
UNION ALL
SELECT n + 1 FROM integers
WHERE n + 1 < 5
)
SELECT f.n, f.val FROM fibonacci f, integers i where f.n = i.n;
'''
)
self.assertEqual(res, [
[1, 1],
[2, 1],
[3, 2],
[4, 3],
])

res = await self.squery_values(
'''
WITH RECURSIVE
a as (SELECT 12 as n),
integers(n) AS (
SELECT 0
UNION ALL
SELECT n + 1 FROM integers
WHERE n + 1 < 5
)
SELECT * FROM a, integers;
'''
)
self.assertEqual(res, [
[12, 0],
[12, 1],
[12, 2],
[12, 3],
[12, 4],
])

async def test_sql_query_introspection_00(self):
dbname = self.con.dbname
res = await self.squery_values(
Expand Down

0 comments on commit 8d37e82

Please sign in to comment.