Skip to content

Commit

Permalink
Update SqlRewritingSubQueryReducer to support CTEs.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 4, 2024
1 parent fbd25ff commit 8c2a073
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
34 changes: 23 additions & 11 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _reduce_parents(
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
right_source_alias=x.right_source_alias,
on_condition=x.on_condition,
join_type=x.join_type,
right_source=join_desc.right_source.accept(self),
right_source_alias=join_desc.right_source_alias,
on_condition=join_desc.on_condition,
join_type=join_desc.join_type,
)
for x in node.join_descs
for join_desc in node.join_descs
),
group_bys=node.group_bys,
order_bys=node.order_bys,
Expand Down Expand Up @@ -199,7 +199,7 @@ def _is_simple_source(node: SqlSelectStatementNode) -> bool:
if select_column.expr.lineage.contains_aggregate_exprs:
return False
return (
len(node.parent_nodes) <= 1
len(node.join_descs) == 0
and len(node.group_bys) == 0
and len(node.order_bys) == 0
and not node.limit
Expand Down Expand Up @@ -419,8 +419,7 @@ def _find_matching_select_column(
return select_column
return None

@staticmethod
def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementNode:
def _rewrite_node_with_join(self, node: SqlSelectStatementNode) -> SqlSelectStatementNode:
"""Reduces nodes with joins if the join source is simple to reduce.
Converts this:
Expand Down Expand Up @@ -579,6 +578,13 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
select_columns=tuple(clauses_to_rewrite.select_columns),
from_source=from_source,
from_source_alias=from_source_alias,
cte_sources=tuple(
SqlCteNode.create(
cte_alias=cte_source.cte_alias,
select_statement=cte_source.select_statement.accept(self),
)
for cte_source in node.cte_sources
),
join_descs=tuple(new_join_descs),
group_bys=tuple(clauses_to_rewrite.group_bys),
order_bys=tuple(clauses_to_rewrite.order_bys),
Expand All @@ -595,7 +601,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
node_with_reduced_parents = self._reduce_parents(node)

if len(node_with_reduced_parents.join_descs) > 0:
return SqlRewritingSubQueryReducerVisitor._rewrite_node_with_join(node_with_reduced_parents)
return self._rewrite_node_with_join(node_with_reduced_parents)

if not self._current_node_can_be_reduced(node_with_reduced_parents):
return node_with_reduced_parents
Expand Down Expand Up @@ -699,6 +705,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
from_source=from_source_select_node.from_source,
from_source_alias=from_source_select_node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=from_source_select_node.join_descs,
group_bys=new_group_bys,
order_bys=tuple(new_order_bys),
Expand Down Expand Up @@ -733,13 +742,13 @@ def _find_matching_select(
) -> Optional[SqlSelectColumn]:
"""Given an expression, find the SELECT column that has the same expression."""
for select_column in select_columns:
if select_column.expr == expr:
if select_column.expr.matches(expr):
return select_column
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
return node.with_new_select(node.select_statement.accept(self))

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
new_group_bys = []
Expand Down Expand Up @@ -767,6 +776,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=node.select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def test_reducing_join_statement(
mf_test_configuration: MetricFlowTestConfiguration,
reducing_join_statement: SqlSelectStatementNode,
) -> None:
"""Tests a case where a join query should not reduced an aggregate."""
"""Tests a case where a join query should not reduce an aggregate."""
assert_default_rendered_sql_equal(
request=request,
mf_test_configuration=mf_test_configuration,
Expand Down

0 comments on commit 8c2a073

Please sign in to comment.