diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py index 2f23e3aa2e..3aec625a22 100644 --- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py @@ -154,6 +154,16 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension aggregation_state=self.aggregation_state, ) + @property + def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102 + return TimeDimensionSpec( + element_name=self.element_name, + entity_links=self.entity_links, + time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity), + date_part=self.date_part, + aggregation_state=self.aggregation_state, + ) + def with_grain_and_date_part( # noqa: D102 self, time_granularity: ExpandedTimeGranularity, date_part: Optional[DatePart] ) -> TimeDimensionSpec: diff --git a/metricflow/dataflow/nodes/join_to_custom_granularity.py b/metricflow/dataflow/nodes/join_to_custom_granularity.py index 14e5008c0e..ada3447fc0 100644 --- a/metricflow/dataflow/nodes/join_to_custom_granularity.py +++ b/metricflow/dataflow/nodes/join_to_custom_granularity.py @@ -14,15 +14,31 @@ @dataclass(frozen=True) class JoinToCustomGranularityNode(DataflowPlanNode, ABC): - """Join parent dataset to time spine dataset to convert time dimension to a custom granularity.""" + """Join parent dataset to time spine dataset to convert time dimension to a custom granularity. + + Args: + time_dimension_spec: The time dimension spec with a custom granularity that will be satisfied by this node. + include_base_grain: Bool that indicates if a spec with the custom granularity's base grain + should be included in the node's output dataset. This is needed when the same time dimension is requested + twice in one query, with both a custom grain and that custom grain's base grain. + """ time_dimension_spec: TimeDimensionSpec + include_base_grain: bool + + def __post_init__(self) -> None: # noqa: D105 + assert ( + self.time_dimension_spec.time_granularity.is_custom_granularity + ), "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity." + f" Instead, found {self.time_dimension_spec.time_granularity.name}. This indicates internal misconfiguration." @staticmethod def create( # noqa: D102 - parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec + parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_grain: bool ) -> JoinToCustomGranularityNode: - return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec) + return JoinToCustomGranularityNode( + parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec, include_base_grain=include_base_grain + ) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -39,6 +55,7 @@ def description(self) -> str: # noqa: D102 def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + ( DisplayedProperty("time_dimension_spec", self.time_dimension_spec), + DisplayedProperty("include_base_grain", self.include_base_grain), ) @property @@ -46,12 +63,18 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102 return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 - return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec + return ( + isinstance(other_node, self.__class__) + and other_node.time_dimension_spec == self.time_dimension_spec + and other_node.include_base_grain == self.include_base_grain + ) def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> JoinToCustomGranularityNode: assert len(new_parent_nodes) == 1, "JoinToCustomGranularity accepts exactly one parent node." return JoinToCustomGranularityNode.create( - parent_node=new_parent_nodes[0], time_dimension_spec=self.time_dimension_spec + parent_node=new_parent_nodes[0], + time_dimension_spec=self.time_dimension_spec, + include_base_grain=self.include_base_grain, ) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 929c4f61b1..0623fc49ef 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -199,6 +199,9 @@ def __init__( self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources( semantic_manifest_lookup.semantic_manifest ) + self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources( + tuple(self._time_spine_sources.values()) + ) @property def column_association_resolver(self) -> ColumnAssociationResolver: # noqa: D102 @@ -262,9 +265,10 @@ def _make_time_spine_data_set( column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name # If the requested granularity is the same as the granularity of the spine, do a direct select. # TODO: also handle date part. - # TODO: [custom granularity] add support for custom granularities to make_time_spine_data_set agg_time_grain = agg_time_dimension_instance.spec.time_granularity - assert not agg_time_grain.is_custom_granularity, "Custom time granularities are not yet supported!" + assert ( + not agg_time_grain.is_custom_granularity + ), "Custom time granularities are not yet supported for all queries." if agg_time_grain.base_granularity == time_spine_source.base_granularity: select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),) # If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by. @@ -306,6 +310,7 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: instance_set=node.data_set.instance_set, ) + # TODO: write tests for custom granularities that hit this node def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: """Generate time range join SQL.""" table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() @@ -1229,6 +1234,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ), ) + # TODO: write tests for custom granularities that hit this node def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() @@ -1252,10 +1258,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet agg_time_dimension_instances.append(instance) # Choose the instance with the smallest standard granularity available. - # TODO: [custom granularity] Update to account for custom granularity instances assert all( [not instance.spec.time_granularity.is_custom_granularity for instance in agg_time_dimension_instances] - ), "Custom granularities are not yet supported!" + ), "Custom granularities are not yet supported for all queries." agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int()) assert len(agg_time_dimension_instances) > 0, ( "Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been " @@ -1341,9 +1346,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet # Add requested granularities (if different from time_spine) and date_parts to time spine column. for time_dimension_instance in time_dimensions_to_select_from_time_spine: time_dimension_spec = time_dimension_instance.spec - - # TODO: this will break when we start supporting smaller grain than DAY unless the time spine table is - # updated to use the smallest available grain. if ( time_dimension_spec.time_granularity.base_granularity.to_int() < original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int() @@ -1408,8 +1410,108 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet ), ) + def _get_time_spine_for_custom_granularity(self, custom_granularity: str) -> TimeSpineSource: + time_spine_source = self._custom_granularity_time_spine_sources.get(custom_granularity) + assert time_spine_source, ( + f"Custom granularity {custom_granularity} does not not exist in time spine sources. " + f"Available custom granularities: {list(self._custom_granularity_time_spine_sources.keys())}" + ) + return time_spine_source + + def _get_custom_granularity_column_name(self, custom_granularity_name: str) -> str: + time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name) + for custom_granularity in time_spine_source.custom_granularities: + print(custom_granularity) + if custom_granularity.name == custom_granularity_name: + return custom_granularity.column_name if custom_granularity.column_name else custom_granularity.name + + raise RuntimeError( + f"Custom granularity {custom_granularity} not found. This indicates internal misconfiguration." + ) + def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102 - raise NotImplementedError # TODO in later commit + parent_data_set = node.parent_node.accept(self) + + # New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node. + parent_alias = parent_data_set.checked_sql_select_node.from_source_alias + parent_time_dimension_instance: Optional[TimeDimensionInstance] = None + for instance in parent_data_set.instance_set.time_dimension_instances: + if instance.spec == node.time_dimension_spec.with_base_grain: + parent_time_dimension_instance = instance + break + assert parent_time_dimension_instance, ( + "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. " + "This indicates internal misconfiguration." + ) + + # Build join expression. + time_spine_alias = self._next_unique_table_alias() + custom_granularity_name = node.time_dimension_spec.time_granularity.name + time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name) + left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name + ) + if parent_time_dimension_instance.spec.time_granularity.base_granularity != time_spine_source.base_granularity: + # If needed, apply DATE_TRUNC to parent column match the time spine spine that's column being joined to. + left_expr_for_join = SqlDateTruncExpression.create( + time_granularity=time_spine_source.base_granularity, arg=left_expr_for_join + ) + join_description = SqlJoinDescription( + right_source=SqlTableNode.create(sql_table=time_spine_source.spine_table), + right_source_alias=time_spine_alias, + on_condition=SqlComparisonExpression.create( + left_expr=left_expr_for_join, + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_alias, column_name=time_spine_source.base_column + ), + ), + join_type=SqlJoinType.LEFT_OUTER, + ) + + # Remove base grain from parent dataset, unless that grain was also requested (in addition to the custom grain). + parent_instance_set = parent_data_set.instance_set + parent_select_columns = parent_data_set.checked_sql_select_node.select_columns + if not node.include_base_grain: + parent_instance_set = parent_instance_set.transform( + FilterElements( + exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,)) + ) + ) + parent_select_columns = tuple( + [ + column + for column in parent_select_columns + if column.column_alias != parent_time_dimension_instance.associated_column.column_name + ] + ) + + # Build output time spine instances and columns. + time_spine_instance = TimeDimensionInstance( + defined_from=parent_time_dimension_instance.defined_from, + associated_columns=(self._column_association_resolver.resolve_spec(node.time_dimension_spec),), + spec=node.time_dimension_spec, + ) + time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,)) + time_spine_select_columns = ( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=time_spine_alias, + column_name=self._get_custom_granularity_column_name(custom_granularity_name), + ), + column_alias=time_spine_instance.associated_column.column_name, + ), + ) + return SqlDataSet( + instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), + sql_select_node=SqlSelectStatementNode.create( + description=node.description + "\n" + parent_data_set.checked_sql_select_node.description, + select_columns=parent_select_columns + time_spine_select_columns, + from_source=parent_data_set.checked_sql_select_node.from_source, + from_source_alias=parent_alias, + join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,), + ), + ) def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self)