From d01cbeac247c8166b26a2dbf94d727d6f1df57a4 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Fri, 1 Nov 2024 11:46:39 -0700 Subject: [PATCH] /* PR_START p--cte 06 */ Split column pruner into two phases. Currently, the column pruner checks the columns that are needed in each `SELECT` statement and generates the pruned SQL in a single pass. For better readability and easier modification, this splits the column pruner into two phases. First, the SQL nodes are traversed to figure out which columns are required and which can be pruned. Then, the SQL nodes are rewritten with the pruned columns. --- metricflow/sql/optimizer/column_pruner.py | 196 ++++------------- .../sql/optimizer/tag_column_aliases.py | 96 +++++++++ .../optimizer/tag_required_column_aliases.py | 200 ++++++++++++++++++ metricflow/sql/sql_plan.py | 9 + .../sql/optimizer/test_column_pruner.py | 10 + 5 files changed, 359 insertions(+), 152 deletions(-) create mode 100644 metricflow/sql/optimizer/tag_column_aliases.py create mode 100644 metricflow/sql/optimizer/tag_required_column_aliases.py diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 79e6d8250a..dde57a12f1 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -1,22 +1,18 @@ from __future__ import annotations import logging -from collections import defaultdict -from typing import Dict, List, Set, Tuple +from typing import FrozenSet, Mapping from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer -from metricflow.sql.sql_exprs import ( - SqlExpressionTreeLineage, -) +from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet +from metricflow.sql.optimizer.tag_required_column_aliases import SqlTagRequiredColumnAliasesVisitor from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, - SqlJoinDescription, SqlQueryPlanNode, SqlQueryPlanNodeVisitor, - SqlSelectColumn, SqlSelectQueryFromClauseNode, SqlSelectStatementNode, SqlTableNode, @@ -28,171 +24,55 @@ class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]): """Removes unnecessary columns from SELECT statements in the SQL query plan. - As the visitor traverses up to the parents, it pushes the list of required columns and rewrites the parent nodes. + This requires a set of tagged column aliases that should be kept for each SQL node. """ def __init__( self, - required_column_aliases: Set[str], + required_alias_mapping: Mapping[SqlQueryPlanNode, FrozenSet[str]], ) -> None: """Constructor. Args: - required_column_aliases: the columns aliases that should not be pruned from the SELECT statements that this - visits. - """ - self._required_column_aliases = required_column_aliases - - def _search_for_expressions( - self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...] - ) -> SqlExpressionTreeLineage: - """Returns the expressions used in the immediate select statement. - - i.e. this does not return expressions used in sub-queries. pruned_select_columns needs to be passed in since the - node may have the select columns pruned. + required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node. """ - all_expr_search_results: List[SqlExpressionTreeLineage] = [] - - for select_column in pruned_select_columns: - all_expr_search_results.append(select_column.expr.lineage) - - for join_description in select_node.join_descs: - if join_description.on_condition: - all_expr_search_results.append(join_description.on_condition.lineage) - - for group_by in select_node.group_bys: - all_expr_search_results.append(group_by.expr.lineage) - - for order_by in select_node.order_bys: - all_expr_search_results.append(order_by.expr.lineage) - - if select_node.where: - all_expr_search_results.append(select_node.where.lineage) - - return SqlExpressionTreeLineage.combine(all_expr_search_results) - - def _prune_columns_from_grandparents( - self, node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...] - ) -> SqlSelectStatementNode: - """Assume that you need all columns from the parent and prune the grandparents.""" - pruned_from_source: SqlQueryPlanNode - if node.from_source.as_select_node: - from_visitor = SqlColumnPrunerVisitor( - required_column_aliases={x.column_alias for x in node.from_source.as_select_node.select_columns} - ) - pruned_from_source = node.from_source.as_select_node.accept(from_visitor) - else: - pruned_from_source = node.from_source - pruned_join_descriptions: List[SqlJoinDescription] = [] - for join_description in node.join_descs: - right_source_as_select_node = join_description.right_source.as_select_node - if right_source_as_select_node: - right_source_visitor = SqlColumnPrunerVisitor( - required_column_aliases={x.column_alias for x in right_source_as_select_node.select_columns} - ) - pruned_join_descriptions.append( - SqlJoinDescription( - right_source=join_description.right_source.accept(right_source_visitor), - right_source_alias=join_description.right_source_alias, - on_condition=join_description.on_condition, - join_type=join_description.join_type, - ) - ) - else: - pruned_join_descriptions.append(join_description) - - return SqlSelectStatementNode.create( - description=node.description, - select_columns=pruned_select_columns, - from_source=pruned_from_source, - from_source_alias=node.from_source_alias, - join_descs=tuple(pruned_join_descriptions), - group_bys=node.group_bys, - order_bys=node.order_bys, - where=node.where, - limit=node.limit, - distinct=node.distinct, - ) - - @override - def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: - raise NotImplementedError + self._required_alias_mapping = required_alias_mapping def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. # Similarly, if this node is a distinct select node, keep all columns as it may return a different result set. + required_column_aliases = self._required_alias_mapping.get(node) + if required_column_aliases is None: + logger.error( + f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version " + f"as it should be valid SQL, but this is a bug and should be investigated." + ) + return node + + if len(required_column_aliases) == 0: + logger.error( + f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid " + f"SQL, but this is a bug and should be investigated." + ) + return node + pruned_select_columns = tuple( select_column for select_column in node.select_columns - if select_column.column_alias in self._required_column_aliases - or select_column in node.group_bys - or node.distinct + if select_column.column_alias in required_column_aliases ) - # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any - # SqlExpressionNode. - - if len(pruned_select_columns) == 0: - raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.") - - # Based on the expressions in this select statement, figure out what column aliases are needed in the sources of - # this query (i.e. tables or sub-queries in the FROM or JOIN clauses). - exprs_used_in_this_node = self._search_for_expressions(node, pruned_select_columns) - - # If any of the string expressions don't have context on what columns are used in the expression, then it's - # impossible to know what columns can be pruned from the parent sources. So return a SELECT statement that - # leaves the parent sources untouched. Columns from the grandparents can be pruned based on the parent node - # though. - if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]): - return self._prune_columns_from_grandparents(node, pruned_select_columns) - - # Create a mapping from the source alias to the column aliases needed from the corresponding source. - source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set) - for column_reference_expr in exprs_used_in_this_node.column_reference_exprs: - column_reference = column_reference_expr.col_ref - source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name) - - # For all string columns, assume that they are needed from all sources since we don't have a table alias - # in SqlStringExpression.used_columns - for string_expr in exprs_used_in_this_node.string_exprs: - if string_expr.used_columns: - for column_alias in string_expr.used_columns: - source_alias_to_required_column_alias[node.from_source_alias].add(column_alias) - for join_description in node.join_descs: - source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias) - # Same with unqualified column references. - for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs: - column_alias = unqualified_column_reference_expr.column_alias - source_alias_to_required_column_alias[node.from_source_alias].add(column_alias) - for join_description in node.join_descs: - source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias) - - # Once we know which column aliases are required from which source aliases, replace the sources with new SELECT - # statements. - from_source_pruner = SqlColumnPrunerVisitor( - required_column_aliases=source_alias_to_required_column_alias[node.from_source_alias] - ) - pruned_from_source = node.from_source.accept(from_source_pruner) - pruned_join_descriptions: List[SqlJoinDescription] = [] - for join_description in node.join_descs: - join_source_pruner = SqlColumnPrunerVisitor( - required_column_aliases=source_alias_to_required_column_alias[join_description.right_source_alias] - ) - pruned_join_descriptions.append( - SqlJoinDescription( - right_source=join_description.right_source.accept(join_source_pruner), - right_source_alias=join_description.right_source_alias, - on_condition=join_description.on_condition, - join_type=join_description.join_type, - ) - ) return SqlSelectStatementNode.create( description=node.description, - select_columns=tuple(pruned_select_columns), - from_source=pruned_from_source, + select_columns=pruned_select_columns, + from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - join_descs=tuple(pruned_join_descriptions), + # TODO: Handle CTEs. + cte_sources=(), + join_descs=tuple( + join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs + ), group_bys=node.group_bys, order_bys=node.order_bys, where=node.where, @@ -214,6 +94,10 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan parent_node=node.parent_node.accept(self), ) + @override + def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: + raise NotImplementedError + class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer): """Removes unnecessary columns in the SELECT clauses.""" @@ -223,8 +107,16 @@ def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102 if not node.as_select_node: return node - pruning_visitor = SqlColumnPrunerVisitor( - required_column_aliases={x.column_alias for x in node.as_select_node.select_columns} + # Figure out which columns in which nodes are required. + tagged_column_alias_set = TaggedColumnAliasSet() + tagged_column_alias_set.tag_all_aliases_in_node(node.as_select_node) + tag_required_column_alias_visitor = SqlTagRequiredColumnAliasesVisitor( + tagged_column_alias_set=tagged_column_alias_set, ) + node.accept(tag_required_column_alias_visitor) + # Re-write the query, pruning columns in the SELECT that are not needed. + pruning_visitor = SqlColumnPrunerVisitor( + required_alias_mapping=tagged_column_alias_set.get_mapping(), + ) return node.accept(pruning_visitor) diff --git a/metricflow/sql/optimizer/tag_column_aliases.py b/metricflow/sql/optimizer/tag_column_aliases.py new file mode 100644 index 0000000000..77b7ee41e6 --- /dev/null +++ b/metricflow/sql/optimizer/tag_column_aliases.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import Dict, FrozenSet, Iterable, Mapping, Set + +from typing_extensions import override + +from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, + SqlCteNode, + SqlQueryPlanNode, + SqlQueryPlanNodeVisitor, + SqlSelectQueryFromClauseNode, + SqlSelectStatementNode, + SqlTableNode, +) + +logger = logging.getLogger(__name__) + + +class TaggedColumnAliasSet: + """Keep track of column aliases in SELECT statements that have been tagged. + + The main use case for this class is to keep track of column aliases / columns that are required so that unnecessary + columns can be pruned. + + For example, in this query: + + SELECT source_0.col_0 AS col_0 + FROM ( + SELECT + example_table.col_0 + example_table.col_1 + FROM example_table + ) source_0 + + this class can be used to tag `example_table.col_0` but not tag `example_table.col_1` since it's not needed for the + query to run correctly. + """ + + def __init__(self) -> None: # noqa: D107 + self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set) + + def get_tagged_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]: + """Return the given tagged column aliases associated with the given SQL node.""" + return frozenset(self._node_to_tagged_aliases[node]) + + def tag_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102 + return self._node_to_tagged_aliases[node].add(column_alias) + + def tag_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102 + self._node_to_tagged_aliases[node].update(column_aliases) + + def tag_all_aliases_in_node(self, node: SqlQueryPlanNode) -> None: + """Convenience method that tags all column aliases in the given SQL node, where appropriate.""" + node.accept(_TagAllColumnAliasesInNodeVisitor(self)) + + def get_mapping(self) -> Mapping[SqlQueryPlanNode, FrozenSet[str]]: + """Return mapping view that associates a given SQL node with the tagged column aliases in that node.""" + return {node: frozenset(tagged_aliases) for node, tagged_aliases in self._node_to_tagged_aliases.items()} + + +class _TagAllColumnAliasesInNodeVisitor(SqlQueryPlanNodeVisitor[None]): + """Visitor to help implement `TaggedColumnAliasSet.tag_all_aliases_in_node`.""" + + def __init__(self, required_column_alias_collector: TaggedColumnAliasSet) -> None: + self._required_column_alias_collector = required_column_alias_collector + + @override + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: + for select_column in node.select_columns: + self._required_column_alias_collector.tag_alias( + node=node, + column_alias=select_column.column_alias, + ) + + @override + def visit_table_node(self, node: SqlTableNode) -> None: + """Columns in a SQL table are not represented.""" + pass + + @override + def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None: + """Columns in an arbitrary SQL query are not represented.""" + pass + + @override + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None: + for parent_node in node.parent_nodes: + parent_node.accept(self) + + @override + def visit_cte_node(self, node: SqlCteNode) -> None: + for parent_node in node.parent_nodes: + parent_node.accept(self) diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py new file mode 100644 index 0000000000..dcddc5f0c7 --- /dev/null +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import Dict, List, Set, Tuple + +from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat +from typing_extensions import override + +from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet +from metricflow.sql.sql_exprs import SqlExpressionTreeLineage +from metricflow.sql.sql_plan import ( + SqlCreateTableAsNode, + SqlCteNode, + SqlQueryPlanNode, + SqlQueryPlanNodeVisitor, + SqlSelectColumn, + SqlSelectQueryFromClauseNode, + SqlSelectStatementNode, + SqlTableNode, +) + +logger = logging.getLogger(__name__) + + +class SqlTagRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]): + """To aid column pruning, traverse the SQL-query representation DAG and tag all column aliases that are required. + + For example, for the query: + + SELECT source_0.col_0 AS col_0_renamed + FROM ( + SELECT + example_table.col_0 + example_table.col_1 + FROM example_table_0 + ) source_0 + + The top-level SQL node would have the column alias `col_0_renamed` tagged, and the SQL node associated with + `source_0` would have `col_0` tagged. Once tagged, the information can be used to prune the columns in the SELECT: + + SELECT source_0.col_0 AS col_0_renamed + FROM ( + SELECT + example_table.col_0 + FROM example_table_0 + ) source_0 + """ + + def __init__(self, tagged_column_alias_set: TaggedColumnAliasSet) -> None: + """Initializer. + + Args: + tagged_column_alias_set: Stores the set of columns that are tagged. This will be updated as the visitor + traverses the SQL-query representation DAG. + """ + self._column_alias_tagger = tagged_column_alias_set + + def _search_for_expressions( + self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...] + ) -> SqlExpressionTreeLineage: + """Returns the expressions used in the immediate select statement. + + i.e. this does not return expressions used in sub-queries. pruned_select_columns needs to be passed in since the + node may have the select columns pruned. + """ + all_expr_search_results: List[SqlExpressionTreeLineage] = [] + + for select_column in pruned_select_columns: + all_expr_search_results.append(select_column.expr.lineage) + + for join_description in select_node.join_descs: + if join_description.on_condition: + all_expr_search_results.append(join_description.on_condition.lineage) + + for group_by in select_node.group_bys: + all_expr_search_results.append(group_by.expr.lineage) + + for order_by in select_node.order_bys: + all_expr_search_results.append(order_by.expr.lineage) + + if select_node.where: + all_expr_search_results.append(select_node.where.lineage) + + return SqlExpressionTreeLineage.combine(all_expr_search_results) + + @override + def visit_cte_node(self, node: SqlCteNode) -> None: + raise NotImplementedError + + def _visit_parents(self, node: SqlQueryPlanNode) -> None: + """Default recursive handler to visit the parents of the given node.""" + for parent_node in node.parent_nodes: + parent_node.accept(self) + return + + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # noqa: D102 + # Based on column aliases that are tagged in this SELECT statement, tag corresponding column aliases in + # parent nodes. + + initial_required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node) + + # If this SELECT statement uses DISTINCT, all columns are required as removing them would change the meaning of + # the query. + updated_required_column_aliases_in_this_node = set(initial_required_column_aliases_in_this_node) + if node.distinct: + updated_required_column_aliases_in_this_node.update( + {select_column.column_alias for select_column in node.select_columns} + ) + + # Any columns in the group by also need to be kept to have a correct query. + updated_required_column_aliases_in_this_node.update( + {group_by_select_column.column_alias for group_by_select_column in node.group_bys} + ) + logger.debug( + LazyFormat( + "Tagging column aliases in parent nodes given what's required in this node", + this_node=node, + initial_required_column_aliases_in_this_node=list(initial_required_column_aliases_in_this_node), + updated_required_column_aliases_in_this_node=list(updated_required_column_aliases_in_this_node), + ) + ) + # Since additional select columns could have been selected due to DISTINCT or GROUP BY, re-tag. + self._column_alias_tagger.tag_aliases(node, updated_required_column_aliases_in_this_node) + + required_select_columns_in_this_node = tuple( + select_column + for select_column in node.select_columns + if select_column.column_alias in updated_required_column_aliases_in_this_node + ) + + # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any + # SqlExpressionNode. + + if len(required_select_columns_in_this_node) == 0: + raise RuntimeError( + "No columns are required in this node - this indicates a bug in this collector or in the inputs." + ) + + # Based on the expressions in this select statement, figure out what column aliases are needed in the sources of + # this query (i.e. tables or sub-queries in the FROM or JOIN clauses). + exprs_used_in_this_node = self._search_for_expressions(node, required_select_columns_in_this_node) + + # If any of the string expressions don't have context on what columns are used in the expression, then it's + # impossible to know what columns can be pruned from the parent sources. Tag all columns in parents as required. + if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]): + for parent_node in node.parent_nodes: + self._column_alias_tagger.tag_all_aliases_in_node(parent_node) + self._visit_parents(node) + return + + # Create a mapping from the source alias to the column aliases needed from the corresponding source. + source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set) + for column_reference_expr in exprs_used_in_this_node.column_reference_exprs: + column_reference = column_reference_expr.col_ref + source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name) + + # Appropriately tag the columns required in the parent nodes. + if node.from_source_alias in source_alias_to_required_column_alias: + aliases_required_in_parent = source_alias_to_required_column_alias[node.from_source_alias] + self._column_alias_tagger.tag_aliases(node=node.from_source, column_aliases=aliases_required_in_parent) + for join_desc in node.join_descs: + if join_desc.right_source_alias in source_alias_to_required_column_alias: + aliases_required_in_parent = source_alias_to_required_column_alias[join_desc.right_source_alias] + self._column_alias_tagger.tag_aliases( + node=join_desc.right_source, column_aliases=aliases_required_in_parent + ) + # TODO: Handle CTEs parent nodes. + + # For all string columns, assume that they are needed from all sources since we don't have a table alias + # in SqlStringExpression.used_columns + for string_expr in exprs_used_in_this_node.string_exprs: + if string_expr.used_columns: + for column_alias in string_expr.used_columns: + for parent_node in node.parent_nodes: + self._column_alias_tagger.tag_alias(parent_node, column_alias) + + # Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say + # it's required from all parents. + # An unqualified column reference expression is like `SELECT col_0` whereas a qualified column reference + # expression is like `SELECT table_0.col_0`. + for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs: + column_alias = unqualified_column_reference_expr.column_alias + for parent_node in node.parent_nodes: + self._column_alias_tagger.tag_alias(parent_node, column_alias) + + # Visit recursively. + self._visit_parents(node) + return + + def visit_table_node(self, node: SqlTableNode) -> None: + """There are no SELECT columns in this node, so pruning cannot apply.""" + return + + def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None: + """Pruning cannot be done here since this is an arbitrary user-provided SQL query.""" + return + + def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None: # noqa: D102 + return self._visit_parents(node) diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 52a7390429..8f82501349 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -98,6 +98,15 @@ class SqlJoinDescription: join_type: SqlJoinType on_condition: Optional[SqlExpressionNode] = None + def with_right_source(self, new_right_source: SqlQueryPlanNode) -> SqlJoinDescription: + """Return a copy of this but with the right source replaced.""" + return SqlJoinDescription( + right_source=new_right_source, + right_source_alias=self.right_source_alias, + join_type=self.join_type, + on_condition=self.on_condition, + ) + @dataclass(frozen=True) class SqlOrderByDescription: # noqa: D101 diff --git a/tests_metricflow/sql/optimizer/test_column_pruner.py b/tests_metricflow/sql/optimizer/test_column_pruner.py index 53b4bc3c51..01760be80a 100644 --- a/tests_metricflow/sql/optimizer/test_column_pruner.py +++ b/tests_metricflow/sql/optimizer/test_column_pruner.py @@ -1,7 +1,10 @@ from __future__ import annotations +import logging + import pytest from _pytest.fixtures import FixtureRequest +from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration @@ -25,6 +28,8 @@ ) from tests_metricflow.sql.compare_sql_plan import assert_default_rendered_sql_equal +logger = logging.getLogger(__name__) + @pytest.fixture def column_pruner() -> SqlColumnPrunerOptimizer: # noqa: D103 @@ -206,6 +211,8 @@ def test_no_pruning( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests a case where no pruning should occur.""" + logger.debug(LazyFormat("Pruning select statement", base_select_statement=base_select_statement.structure_text())) + assert_default_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, @@ -900,6 +907,9 @@ def test_prune_distinct_select( ), from_source_alias="b", ) + + logger.debug(LazyFormat("Pruning select statement", select_statement=select_node.structure_text())) + assert_default_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration,