Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split column pruner into two phases #1501

Open
wants to merge 1 commit into
base: p--cte--05
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 44 additions & 152 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from typing import FrozenSet, Mapping

from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
SqlExpressionTreeLineage,
)
from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet
from metricflow.sql.optimizer.tag_required_column_aliases import SqlTagRequiredColumnAliasesVisitor
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
Expand All @@ -28,171 +24,55 @@
class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
"""Removes unnecessary columns from SELECT statements in the SQL query plan.

As the visitor traverses up to the parents, it pushes the list of required columns and rewrites the parent nodes.
This requires a set of tagged column aliases that should be kept for each SQL node.
"""

def __init__(
self,
required_column_aliases: Set[str],
required_alias_mapping: Mapping[SqlQueryPlanNode, FrozenSet[str]],
) -> None:
"""Constructor.

Args:
required_column_aliases: the columns aliases that should not be pruned from the SELECT statements that this
visits.
"""
self._required_column_aliases = required_column_aliases

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
) -> SqlExpressionTreeLineage:
"""Returns the expressions used in the immediate select statement.

i.e. this does not return expressions used in sub-queries. pruned_select_columns needs to be passed in since the
node may have the select columns pruned.
required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node.
"""
all_expr_search_results: List[SqlExpressionTreeLineage] = []

for select_column in pruned_select_columns:
all_expr_search_results.append(select_column.expr.lineage)

for join_description in select_node.join_descs:
if join_description.on_condition:
all_expr_search_results.append(join_description.on_condition.lineage)

for group_by in select_node.group_bys:
all_expr_search_results.append(group_by.expr.lineage)

for order_by in select_node.order_bys:
all_expr_search_results.append(order_by.expr.lineage)

if select_node.where:
all_expr_search_results.append(select_node.where.lineage)

return SqlExpressionTreeLineage.combine(all_expr_search_results)

def _prune_columns_from_grandparents(
self, node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
) -> SqlSelectStatementNode:
"""Assume that you need all columns from the parent and prune the grandparents."""
pruned_from_source: SqlQueryPlanNode
if node.from_source.as_select_node:
from_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in node.from_source.as_select_node.select_columns}
)
pruned_from_source = node.from_source.as_select_node.accept(from_visitor)
else:
pruned_from_source = node.from_source
pruned_join_descriptions: List[SqlJoinDescription] = []
for join_description in node.join_descs:
right_source_as_select_node = join_description.right_source.as_select_node
if right_source_as_select_node:
right_source_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in right_source_as_select_node.select_columns}
)
pruned_join_descriptions.append(
SqlJoinDescription(
right_source=join_description.right_source.accept(right_source_visitor),
right_source_alias=join_description.right_source_alias,
on_condition=join_description.on_condition,
join_type=join_description.join_type,
)
)
else:
pruned_join_descriptions.append(join_description)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=pruned_select_columns,
from_source=pruned_from_source,
from_source_alias=node.from_source_alias,
join_descs=tuple(pruned_join_descriptions),
group_bys=node.group_bys,
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
self._required_alias_mapping = required_alias_mapping

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
required_column_aliases = self._required_alias_mapping.get(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
f"as it should be valid SQL, but this is a bug and should be investigated."
)
return node

if len(required_column_aliases) == 0:
logger.error(
f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid "
f"SQL, but this is a bug and should be investigated."
)
return node

pruned_select_columns = tuple(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tangential, but I've frequently read this code and found this variable name confusing (pruned_select_columns). We frequently refer to "pruned columns" when we mean the ones that have been removed, but in this case we mean the columns that have been kept. I think the word pruned can technically be used both ways, but it typically is used to refer to what has been removed. Can we change this to a more clear variable name?

select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases
or select_column in node.group_bys
or node.distinct
if select_column.column_alias in required_column_aliases
)
# TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any
# SqlExpressionNode.

if len(pruned_select_columns) == 0:
raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.")

# Based on the expressions in this select statement, figure out what column aliases are needed in the sources of
# this query (i.e. tables or sub-queries in the FROM or JOIN clauses).
exprs_used_in_this_node = self._search_for_expressions(node, pruned_select_columns)

# If any of the string expressions don't have context on what columns are used in the expression, then it's
# impossible to know what columns can be pruned from the parent sources. So return a SELECT statement that
# leaves the parent sources untouched. Columns from the grandparents can be pruned based on the parent node
# though.
if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]):
return self._prune_columns_from_grandparents(node, pruned_select_columns)

# Create a mapping from the source alias to the column aliases needed from the corresponding source.
source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set)
for column_reference_expr in exprs_used_in_this_node.column_reference_exprs:
column_reference = column_reference_expr.col_ref
source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name)

# For all string columns, assume that they are needed from all sources since we don't have a table alias
# in SqlStringExpression.used_columns
for string_expr in exprs_used_in_this_node.string_exprs:
if string_expr.used_columns:
for column_alias in string_expr.used_columns:
source_alias_to_required_column_alias[node.from_source_alias].add(column_alias)
for join_description in node.join_descs:
source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias)
# Same with unqualified column references.
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs:
column_alias = unqualified_column_reference_expr.column_alias
source_alias_to_required_column_alias[node.from_source_alias].add(column_alias)
for join_description in node.join_descs:
source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias)

# Once we know which column aliases are required from which source aliases, replace the sources with new SELECT
# statements.
from_source_pruner = SqlColumnPrunerVisitor(
required_column_aliases=source_alias_to_required_column_alias[node.from_source_alias]
)
pruned_from_source = node.from_source.accept(from_source_pruner)
pruned_join_descriptions: List[SqlJoinDescription] = []
for join_description in node.join_descs:
join_source_pruner = SqlColumnPrunerVisitor(
required_column_aliases=source_alias_to_required_column_alias[join_description.right_source_alias]
)
pruned_join_descriptions.append(
SqlJoinDescription(
right_source=join_description.right_source.accept(join_source_pruner),
right_source_alias=join_description.right_source_alias,
on_condition=join_description.on_condition,
join_type=join_description.join_type,
)
)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=tuple(pruned_select_columns),
from_source=pruned_from_source,
select_columns=pruned_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
join_descs=tuple(pruned_join_descriptions),
# TODO: Handle CTEs.
cte_sources=(),
join_descs=tuple(
join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs
),
group_bys=node.group_bys,
order_bys=node.order_bys,
where=node.where,
Expand All @@ -214,6 +94,10 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan
parent_node=node.parent_node.accept(self),
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
"""Removes unnecessary columns in the SELECT clauses."""
Expand All @@ -223,8 +107,16 @@ def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102
if not node.as_select_node:
return node

pruning_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in node.as_select_node.select_columns}
# Figure out which columns in which nodes are required.
tagged_column_alias_set = TaggedColumnAliasSet()
tagged_column_alias_set.tag_all_aliases_in_node(node.as_select_node)
tag_required_column_alias_visitor = SqlTagRequiredColumnAliasesVisitor(
tagged_column_alias_set=tagged_column_alias_set,
)
node.accept(tag_required_column_alias_visitor)

# Re-write the query, pruning columns in the SELECT that are not needed.
pruning_visitor = SqlColumnPrunerVisitor(
required_alias_mapping=tagged_column_alias_set.get_mapping(),
)
return node.accept(pruning_visitor)
96 changes: 96 additions & 0 deletions metricflow/sql/optimizer/tag_column_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Iterable, Mapping, Set

from typing_extensions import override

from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)


class TaggedColumnAliasSet:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it very unintuitive to understand what you meant by "tag" in this whole PR. I would recommend changing that word to something else more clear everywhere it's used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this class specifically - it feels like the name implies a simple dataclass / storage object. I would recommend changing the name to something like ColumnAliasCollector or SqlNodeColumnAliasLinker.

"""Keep track of column aliases in SELECT statements that have been tagged.

The main use case for this class is to keep track of column aliases / columns that are required so that unnecessary
columns can be pruned.

For example, in this query:

SELECT source_0.col_0 AS col_0
FROM (
SELECT
example_table.col_0
example_table.col_1
FROM example_table
) source_0

this class can be used to tag `example_table.col_0` but not tag `example_table.col_1` since it's not needed for the
query to run correctly.
"""

def __init__(self) -> None: # noqa: D107
self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set)

def get_tagged_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the given tagged column aliases associated with the given SQL node."""
return frozenset(self._node_to_tagged_aliases[node])

def tag_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
return self._node_to_tagged_aliases[node].add(column_alias)

def tag_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
self._node_to_tagged_aliases[node].update(column_aliases)

def tag_all_aliases_in_node(self, node: SqlQueryPlanNode) -> None:
"""Convenience method that tags all column aliases in the given SQL node, where appropriate."""
node.accept(_TagAllColumnAliasesInNodeVisitor(self))

def get_mapping(self) -> Mapping[SqlQueryPlanNode, FrozenSet[str]]:
"""Return mapping view that associates a given SQL node with the tagged column aliases in that node."""
return {node: frozenset(tagged_aliases) for node, tagged_aliases in self._node_to_tagged_aliases.items()}


class _TagAllColumnAliasesInNodeVisitor(SqlQueryPlanNodeVisitor[None]):
"""Visitor to help implement `TaggedColumnAliasSet.tag_all_aliases_in_node`."""

def __init__(self, required_column_alias_collector: TaggedColumnAliasSet) -> None:
self._required_column_alias_collector = required_column_alias_collector

@override
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
for select_column in node.select_columns:
self._required_column_alias_collector.tag_alias(
node=node,
column_alias=select_column.column_alias,
)

@override
def visit_table_node(self, node: SqlTableNode) -> None:
"""Columns in a SQL table are not represented."""
pass

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
"""Columns in an arbitrary SQL query are not represented."""
pass

@override
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)
Loading
Loading