Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CTEs in sub-query reducer #1504

Open
wants to merge 3 commits into
base: p--cte--08
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def _reduce_parents(
select_columns=node.select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
Expand Down Expand Up @@ -150,7 +153,7 @@ def _select_column_for_alias(column_alias: str, select_columns: Sequence[SqlSele
for select_column in select_columns:
if select_column.column_alias == column_alias:
return select_column
raise RuntimeError(f"Column alias '{column_alias}' not in SELECT columns: {select_columns}")
raise RuntimeError(f"Column alias {repr(column_alias)} not in SELECT columns: {select_columns}")

@staticmethod
def _is_simple_source(node: SqlSelectStatementNode) -> bool:
Expand Down Expand Up @@ -591,7 +594,7 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

if len(node_with_reduced_parents.parent_nodes) > 1:
if len(node_with_reduced_parents.join_descs) > 0:
return SqlRewritingSubQueryReducerVisitor._rewrite_node_with_join(node_with_reduced_parents)

if not self._current_node_can_be_reduced(node_with_reduced_parents):
Expand All @@ -612,10 +615,11 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# JOIN dim_listings c
# ON a.listing_id = b.listing_id

assert len(node_with_reduced_parents.parent_nodes) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there a reason to remove this assertion?

parent_node = node_with_reduced_parents.parent_nodes[0]
parent_select_node = parent_node.as_select_node
assert parent_select_node
from_source_node = node_with_reduced_parents.parent_nodes[0]
from_source_select_node = from_source_node.as_select_node
assert (
from_source_select_node is not None
), f"{from_source_select_node=} should be set as `_current_node_can_be_reduced()` returned True"

# At this point, the query should look similar to
#
Expand All @@ -631,7 +635,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# The ORDER BY in the parent doesn't matter since the order by in this node will "overwrite" the order in the
# parent as long as the parent has no limits.
column_replacements = SqlRewritingSubQueryReducerVisitor._get_column_replacements(
parent_node=parent_select_node,
parent_node=from_source_select_node,
parent_node_alias=node.from_source_alias,
)
new_order_bys: List[SqlOrderByDescription] = []
Expand Down Expand Up @@ -671,12 +675,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# The limit should be the min of this SELECT limit and the parent SELECT limit.
new_limit: Optional[int] = node_with_reduced_parents.limit
if new_limit is None:
new_limit = parent_select_node.limit
elif parent_select_node.limit is not None:
new_limit = min(new_limit, parent_select_node.limit)
new_limit = from_source_select_node.limit
elif from_source_select_node.limit is not None:
new_limit = min(new_limit, from_source_select_node.limit)

new_group_bys: Tuple[SqlSelectColumn, ...] = ()
if node.group_bys and parent_select_node.group_bys:
if node.group_bys and from_source_select_node.group_bys:
raise RuntimeError(
"Attempting to reduce sub-queries when this and the parent have GROUP BYs. This should have been "
"prevent by _should_reduce()"
Expand All @@ -685,26 +689,26 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
new_group_bys = SqlRewritingSubQueryReducerVisitor._rewrite_select_columns(
old_select_columns=node.group_bys, column_replacements=column_replacements
)
elif parent_select_node.group_bys:
new_group_bys = parent_select_node.group_bys
elif from_source_select_node.group_bys:
new_group_bys = from_source_select_node.group_bys

return SqlSelectStatementNode.create(
description="\n".join([parent_select_node.description, node_with_reduced_parents.description]),
description="\n".join([from_source_select_node.description, node_with_reduced_parents.description]),
select_columns=SqlRewritingSubQueryReducerVisitor._rewrite_select_columns(
old_select_columns=node.select_columns, column_replacements=column_replacements
),
from_source=parent_select_node.from_source,
from_source_alias=parent_select_node.from_source_alias,
join_descs=parent_select_node.join_descs,
from_source=from_source_select_node.from_source,
from_source_alias=from_source_select_node.from_source_alias,
join_descs=from_source_select_node.join_descs,
group_bys=new_group_bys,
order_bys=tuple(new_order_bys),
where=SqlRewritingSubQueryReducerVisitor._rewrite_where(
column_replacements=column_replacements,
node_where=node.where,
parent_node_where=parent_select_node.where,
parent_node_where=from_source_select_node.where,
),
limit=new_limit,
distinct=parent_select_node.distinct,
distinct=from_source_select_node.distinct,
)

def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down