Skip to content

Commit

Permalink
Factoring of columns in USING clause over SQL adapter (#7923)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Oct 30, 2024
1 parent 0120521 commit 876ac76
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 1 deletion.
13 changes: 13 additions & 0 deletions edb/pgsql/resolver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ class Scope:
# Common Table Expressions
ctes: List[CTE] = field(default_factory=lambda: [])

# Pairs of columns of the same name that have been compared in a USING
# clause. This makes unqualified references to their name unambiguous.
# Each entry is (column name, left table, right table, join type).
factored_columns: List[Tuple[str, Table, Table, str]] = field(
default_factory=lambda: []
)


@dataclass(kw_only=True)
class Table:
Expand Down Expand Up @@ -151,6 +158,12 @@ class ColumnComputable(ColumnKind):
pointer: s_pointers.Pointer


@dataclass(kw_only=True)
class ColumnPgExpr(ColumnKind):
# Column kind that carries a ready-made pgast expression, provided by some
# special resolver path. The resolver's `resolve_column_kind` returns `expr`
# as-is for this kind, and `_lookup_column` builds one to hold the COALESCE
# of both sides when a column is factored by a FULL join's USING clause.
expr: pgast.BaseExpr


@dataclass(kw_only=True, eq=False, slots=True, repr=False)
class CompiledDML:
# relation that provides the DML value. not yet resolved.
Expand Down
42 changes: 41 additions & 1 deletion edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def resolve_column_kind(
case context.ColumnStaticVal(val=val):
# special case: __type__ static value
return _uuid_const(val)
case context.ColumnPgExpr(expr=e):
return e
case context.ColumnComputable(pointer=pointer):

expr = pointer.get_expr(ctx.schema)
Expand Down Expand Up @@ -268,6 +270,44 @@ def _lookup_column(
(t, c) for t, c in matched_columns if t.precedence == max_precedence
]

# When an ambiguous reference names a column that was used in a USING clause,
# we resolve it by join type: the left column for INNER and LEFT joins, the
# right column for RIGHT joins, or a COALESCE of the two for FULL joins.
if (
len(matched_columns) == 2
and matched_columns[0][1].name == matched_columns[1][1].name
):
matched_name = matched_columns[0][1].name
matched_tables = [t for t, _c in matched_columns]

for c_name, t_left, t_right, join_type in ctx.scope.factored_columns:
if matched_name != c_name:
continue
if not (t_left in matched_tables and t_right in matched_tables):
continue

c_left = next(c for c in t_left.columns if c.name == c_name)
c_right = next(c for c in t_right.columns if c.name == c_name)

if join_type == 'INNER' or join_type == 'LEFT':
matched_columns = [(t_left, c_left)]
elif join_type == 'RIGHT':
matched_columns = [(t_right, c_right)]
elif join_type == 'FULL':
coalesce = pgast.CoalesceExpr(
args=[
resolve_column_kind(t_left, c_left.kind, ctx=ctx),
resolve_column_kind(t_right, c_right.kind, ctx=ctx),
]
)
c_coalesce = context.Column(
name=c_name,
kind=context.ColumnPgExpr(expr=coalesce),
)
matched_columns = [(t_left, c_coalesce)]
else:
raise NotImplementedError()
break

if len(matched_columns) > 1:
potential_tables = ', '.join([t.name or '' for t, _ in matched_columns])
raise errors.QueryError(
Expand All @@ -276,7 +316,7 @@ def _lookup_column(
span=column_ref.span,
)

return (matched_columns[0],)
return matched_columns


def _lookup_in_table(
Expand Down
9 changes: 9 additions & 0 deletions edb/pgsql/resolver/range_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,21 @@ def _resolve_JoinExpr(

if join.using_clause:
for c in join.using_clause:
assert len(c.name) == 1
assert isinstance(c.name[-1], str)
c_name = c.name[-1]

with ctx.child() as subctx:
subctx.scope.tables = [ltable]
l_expr = dispatch.resolve(c, ctx=subctx)
with ctx.child() as subctx:
subctx.scope.tables = [rtable]
r_expr = dispatch.resolve(c, ctx=subctx)

ctx.scope.factored_columns.append(
(c_name, ltable, rtable, join.type)
)

quals = pgastutils.extend_binop(
quals,
pgast.Expr(
Expand Down
69 changes: 69 additions & 0 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,75 @@ async def test_sql_query_42(self):
self.assertEqual(res, 'UPDATE 1')
await tran.rollback()

async def test_sql_query_43(self):
# USING factoring: a column named in a USING clause can afterwards be
# referenced without a table qualifier. Most queries below select a.id,
# b.id and the unqualified id, so the three values can be compared per row.

# LEFT JOIN: the unqualified id is still set on the row where b has no
# match (a.id = 2), i.e. it follows the left table.
res = await self.squery_values(
'''
WITH
a(id) AS (SELECT 1 UNION SELECT 2),
b(id) AS (SELECT 1 UNION SELECT 3)
SELECT a.id, b.id, id
FROM a LEFT JOIN b USING (id);
'''
)
self.assertEqual(res, [[1, 1, 1], [2, None, 2]])

# Plain JOIN with two USING columns: both id and sub_id can be
# referenced unqualified.
res = await self.squery_values(
'''
WITH
a(id, sub_id) AS (SELECT 1, 'a' UNION SELECT 2, 'b'),
b(id, sub_id) AS (SELECT 1, 'a' UNION SELECT 3, 'c')
SELECT a.id, a.sub_id, b.id, b.sub_id, id, sub_id
FROM a JOIN b USING (id, sub_id);
'''
)
self.assertEqual(res, [[1, 'a', 1, 'a', 1, 'a']])

# INNER JOIN: only the row present on both sides remains.
res = await self.squery_values(
'''
WITH
a(id) AS (SELECT 1 UNION SELECT 2),
b(id) AS (SELECT 1 UNION SELECT 3)
SELECT a.id, b.id, id
FROM a INNER JOIN b USING (id);
'''
)
self.assertEqual(res, [[1, 1, 1]])

# RIGHT JOIN: the unqualified id is still set on the row where a has no
# match (b.id = 3), i.e. it follows the right table.
res = await self.squery_values(
'''
WITH
a(id) AS (SELECT 1 UNION SELECT 2),
b(id) AS (SELECT 1 UNION SELECT 3)
SELECT a.id, b.id, id
FROM a RIGHT JOIN b USING (id);
'''
)
self.assertEqual(res, [[1, 1, 1], [None, 3, 3]])

# RIGHT OUTER JOIN: expected to give the same rows as RIGHT JOIN.
res = await self.squery_values(
'''
WITH
a(id) AS (SELECT 1 UNION SELECT 2),
b(id) AS (SELECT 1 UNION SELECT 3)
SELECT a.id, b.id, id
FROM a RIGHT OUTER JOIN b USING (id);
'''
)
self.assertEqual(res, [[1, 1, 1], [None, 3, 3]])

# FULL JOIN: the unqualified id is non-null on every row, taking its
# value from whichever side supplied the row.
res = await self.squery_values(
'''
WITH
a(id) AS (SELECT 1 UNION SELECT 2),
b(id) AS (SELECT 1 UNION SELECT 3)
SELECT a.id, b.id, id
FROM a FULL JOIN b USING (id);
'''
)
self.assertEqual(res, [[1, 1, 1], [2, None, 2], [None, 3, 3]])

async def test_sql_query_introspection_00(self):
dbname = self.con.dbname
res = await self.squery_values(
Expand Down

0 comments on commit 876ac76

Please sign in to comment.