diff --git a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py index a2792ca510..2aa3ac6d0b 100644 --- a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py +++ b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py @@ -39,6 +39,10 @@ def contains_metric_time(self) -> bool: """Returns true if this set contains a spec referring to metric time at any grain.""" return len(self.metric_time_specs) > 0 + @property + def time_dimension_specs_with_custom_grain(self) -> Tuple[TimeDimensionSpec, ...]: # noqa: D102 + return tuple([spec for spec in self.time_dimension_specs if spec.time_granularity.is_custom_granularity]) + def included_agg_time_dimension_specs_for_metric( self, metric_reference: MetricReference, metric_lookup: MetricLookup ) -> List[TimeDimensionSpec]: diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 22c381a039..b66cd7b128 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -798,13 +798,8 @@ def _build_plan_for_distinct_values( for time_dimension_spec in required_linkable_specs.time_dimension_specs: if time_dimension_spec.time_granularity.is_custom_granularity: - include_base_grain = ( - time_dimension_spec.with_base_grain() in required_linkable_specs.time_dimension_specs - ) output_node = JoinToCustomGranularityNode.create( - parent_node=output_node, - time_dimension_spec=time_dimension_spec, - include_base_grain=include_base_grain, + parent_node=output_node, time_dimension_spec=time_dimension_spec ) if len(query_level_filter_specs) > 0: @@ -1410,6 +1405,12 @@ def __get_required_and_extraneous_linkable_specs( extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe() required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe() + # Custom grains require joining to their base grain, so add base grain to extraneous specs. + base_grain_set = LinkableSpecSet.create_from_specs( + [spec.with_base_grain() for spec in required_linkable_specs.time_dimension_specs_with_custom_grain] + ) + extraneous_linkable_specs = extraneous_linkable_specs.merge(base_grain_set).dedupe() + return required_linkable_specs, extraneous_linkable_specs def _build_aggregated_measure_from_measure_source_node( @@ -1562,7 +1563,12 @@ def _build_aggregated_measure_from_measure_source_node( ) specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge( - InstanceSpecSet.create_from_specs(required_linkable_specs.as_tuple), + InstanceSpecSet.create_from_specs( + [ + spec.with_base_grain() if isinstance(spec, TimeDimensionSpec) else spec + for spec in required_linkable_specs.as_tuple + ] + ), ) after_join_filtered_node = FilterElementsNode.create( @@ -1572,15 +1578,10 @@ def _build_aggregated_measure_from_measure_source_node( else: unaggregated_measure_node = filtered_measure_source_node - for time_dimension_spec in queried_linkable_specs.time_dimension_specs: + for time_dimension_spec in required_linkable_specs.time_dimension_specs: if time_dimension_spec.time_granularity.is_custom_granularity: - include_base_grain = ( - time_dimension_spec.with_base_grain() in required_linkable_specs.time_dimension_specs - ) unaggregated_measure_node = JoinToCustomGranularityNode.create( - parent_node=unaggregated_measure_node, - time_dimension_spec=time_dimension_spec, - include_base_grain=include_base_grain, + parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec ) # If time constraint was previously adjusted for cumulative window or grain, apply original time constraint diff --git a/metricflow/dataflow/nodes/join_to_custom_granularity.py b/metricflow/dataflow/nodes/join_to_custom_granularity.py index ada3447fc0..6f13f6ece4 100644 --- a/metricflow/dataflow/nodes/join_to_custom_granularity.py +++ b/metricflow/dataflow/nodes/join_to_custom_granularity.py @@ -18,15 +18,12 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC): Args: time_dimension_spec: The time dimension spec with a custom granularity that will be satisfied by this node. - include_base_grain: Bool that indicates if a spec with the custom granularity's base grain - should be included in the node's output dataset. This is needed when the same time dimension is requested - twice in one query, with both a custom grain and that custom grain's base grain. """ time_dimension_spec: TimeDimensionSpec - include_base_grain: bool def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() assert ( self.time_dimension_spec.time_granularity.is_custom_granularity ), "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity." @@ -34,11 +31,9 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 - parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_grain: bool + parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec ) -> JoinToCustomGranularityNode: - return JoinToCustomGranularityNode( - parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec, include_base_grain=include_base_grain - ) + return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 @@ -55,7 +50,6 @@ def description(self) -> str: # noqa: D102 def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + ( DisplayedProperty("time_dimension_spec", self.time_dimension_spec), - DisplayedProperty("include_base_grain", self.include_base_grain), ) @property @@ -63,11 +57,7 @@ def parent_node(self) -> DataflowPlanNode: # noqa: D102 return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 - return ( - isinstance(other_node, self.__class__) - and other_node.time_dimension_spec == self.time_dimension_spec - and other_node.include_base_grain == self.include_base_grain - ) + return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] @@ -76,5 +66,4 @@ def with_new_parents( # noqa: D102 return JoinToCustomGranularityNode.create( parent_node=new_parent_nodes[0], time_dimension_spec=self.time_dimension_spec, - include_base_grain=self.include_base_grain, ) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index d3041bee2c..a2de4cefc7 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -1464,23 +1464,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod join_type=SqlJoinType.LEFT_OUTER, ) - # Remove base grain from parent dataset, unless that grain was also requested (in addition to the custom grain). - parent_instance_set = parent_data_set.instance_set - parent_select_columns = parent_data_set.checked_sql_select_node.select_columns - if not node.include_base_grain: - parent_instance_set = parent_instance_set.transform( - FilterElements( - exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,)) - ) - ) - parent_select_columns = tuple( - [ - column - for column in parent_select_columns - if column.column_alias != parent_time_dimension_instance.associated_column.column_name - ] - ) - # Build output time spine instances and columns. time_spine_instance = TimeDimensionInstance( defined_from=parent_time_dimension_instance.defined_from, @@ -1498,10 +1481,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod ), ) return SqlDataSet( - instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), + instance_set=InstanceSet.merge([time_spine_instance_set, parent_data_set.instance_set]), sql_select_node=SqlSelectStatementNode.create( - description=node.description + "\n" + parent_data_set.checked_sql_select_node.description, - select_columns=parent_select_columns + time_spine_select_columns, + description=parent_data_set.checked_sql_select_node.description + "\n" + node.description, + select_columns=parent_data_set.checked_sql_select_node.select_columns + time_spine_select_columns, from_source=parent_data_set.checked_sql_select_node.from_source, from_source_alias=parent_alias, join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),