Skip to content

Commit

Permalink
Bug fix: use parent column expr in custom grain join condition in case column gets pruned
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 24, 2024
1 parent 454411e commit 5ef7eb6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 14 additions & 5 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,23 +1438,32 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
if instance.spec == node.time_dimension_spec.with_base_grain():
parent_time_dimension_instance = instance
break
parent_column: Optional[SqlSelectColumn] = None
assert parent_time_dimension_instance, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
"This indicates internal misconfiguration."
f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain}; "
f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}"
)
for select_column in parent_data_set.checked_sql_select_node.select_columns:
if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name:
parent_column = select_column
break
assert parent_column, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent columns. "
f"This indicates internal misconfiguration. Expected: "
f"{parent_time_dimension_instance.associated_column.column_name}; Got: "
f"{[column.column_alias for column in parent_data_set.checked_sql_select_node.select_columns]}"
)

# Build join expression.
time_spine_alias = self._next_unique_table_alias()
custom_granularity_name = node.time_dimension_spec.time_granularity.name
time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name)
left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name
)
join_description = SqlJoinDescription(
right_source=SqlTableNode.create(sql_table=time_spine_source.spine_table),
right_source_alias=time_spine_alias,
on_condition=SqlComparisonExpression.create(
left_expr=left_expr_for_join,
left_expr=parent_column.expr,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_spine_source.base_column
Expand Down
2 changes: 2 additions & 0 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
or select_column in node.group_bys
or node.distinct
)
# TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any
# SqlExpressionNode.

if len(pruned_select_columns) == 0:
raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.")
Expand Down

0 comments on commit 5ef7eb6

Please sign in to comment.