Skip to content

Commit

Permalink
Filter elements properly
Browse files Browse the repository at this point in the history
This change renders the include_base_grain property useless, so I removed that too.
  • Loading branch information
courtneyholcomb committed Sep 24, 2024
1 parent d5e14d2 commit 3050414
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def contains_metric_time(self) -> bool:
"""Returns true if this set contains a spec referring to metric time at any grain."""
return len(self.metric_time_specs) > 0

@property
def time_dimension_specs_with_custom_grain(self) -> Tuple[TimeDimensionSpec, ...]:
    """Return the time dimension specs in this set whose granularity is a custom granularity.

    Specs are returned in the same order they appear in ``self.time_dimension_specs``.
    """
    return tuple(
        candidate for candidate in self.time_dimension_specs if candidate.time_granularity.is_custom_granularity
    )

def included_agg_time_dimension_specs_for_metric(
self, metric_reference: MetricReference, metric_lookup: MetricLookup
) -> List[TimeDimensionSpec]:
Expand Down
29 changes: 15 additions & 14 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,13 +798,8 @@ def _build_plan_for_distinct_values(

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
include_base_grain = (
time_dimension_spec.with_base_grain() in required_linkable_specs.time_dimension_specs
)
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node,
time_dimension_spec=time_dimension_spec,
include_base_grain=include_base_grain,
parent_node=output_node, time_dimension_spec=time_dimension_spec
)

if len(query_level_filter_specs) > 0:
Expand Down Expand Up @@ -1410,6 +1405,12 @@ def __get_required_and_extraneous_linkable_specs(
extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe()
required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe()

# Custom grains require joining to their base grain, so add base grain to extraneous specs.
base_grain_set = LinkableSpecSet.create_from_specs(
[spec.with_base_grain() for spec in required_linkable_specs.time_dimension_specs_with_custom_grain]
)
extraneous_linkable_specs = extraneous_linkable_specs.merge(base_grain_set).dedupe()

return required_linkable_specs, extraneous_linkable_specs

def _build_aggregated_measure_from_measure_source_node(
Expand Down Expand Up @@ -1562,7 +1563,12 @@ def _build_aggregated_measure_from_measure_source_node(
)

specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_specs(required_linkable_specs.as_tuple),
InstanceSpecSet.create_from_specs(
[
spec.with_base_grain() if isinstance(spec, TimeDimensionSpec) else spec
for spec in required_linkable_specs.as_tuple
]
),
)

after_join_filtered_node = FilterElementsNode.create(
Expand All @@ -1572,15 +1578,10 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in queried_linkable_specs.time_dimension_specs:
for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
include_base_grain = (
time_dimension_spec.with_base_grain() in required_linkable_specs.time_dimension_specs
)
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node,
time_dimension_spec=time_dimension_spec,
include_base_grain=include_base_grain,
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
Expand Down
19 changes: 4 additions & 15 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,22 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
Args:
time_dimension_spec: The time dimension spec with a custom granularity that will be satisfied by this node.
include_base_grain: Bool that indicates if a spec with the custom granularity's base grain
should be included in the node's output dataset. This is needed when the same time dimension is requested
twice in one query, with both a custom grain and that custom grain's base grain.
"""

time_dimension_spec: TimeDimensionSpec
include_base_grain: bool

def __post_init__(self) -> None:
    """Validate that this node's time dimension spec uses a custom granularity.

    Raises:
        AssertionError: If ``time_dimension_spec.time_granularity`` is not a custom granularity.
    """
    super().__post_init__()
    # Both literals are wrapped in parentheses so they concatenate into a single assert message.
    # Without the parentheses the f-string is a separate no-op statement and the
    # "Instead, found ..." detail is never included in the raised error.
    assert self.time_dimension_spec.time_granularity.is_custom_granularity, (
        "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity."
        f" Instead, found {self.time_dimension_spec.time_granularity.name}. This indicates internal misconfiguration."
    )

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_grain: bool
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
) -> JoinToCustomGranularityNode:
return JoinToCustomGranularityNode(
parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec, include_base_grain=include_base_grain
)
return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
Expand All @@ -55,19 +50,14 @@ def description(self) -> str: # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("time_dimension_spec", self.time_dimension_spec),
DisplayedProperty("include_base_grain", self.include_base_grain),
)

@property
def parent_node(self) -> DataflowPlanNode:
    """Return the first entry of ``self.parent_nodes``."""
    first_parent = self.parent_nodes[0]
    return first_parent

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.time_dimension_spec == self.time_dimension_spec
and other_node.include_base_grain == self.include_base_grain
)
return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
Expand All @@ -76,5 +66,4 @@ def with_new_parents( # noqa: D102
return JoinToCustomGranularityNode.create(
parent_node=new_parent_nodes[0],
time_dimension_spec=self.time_dimension_spec,
include_base_grain=self.include_base_grain,
)
23 changes: 3 additions & 20 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,23 +1464,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
join_type=SqlJoinType.LEFT_OUTER,
)

# Remove base grain from parent dataset, unless that grain was also requested (in addition to the custom grain).
parent_instance_set = parent_data_set.instance_set
parent_select_columns = parent_data_set.checked_sql_select_node.select_columns
if not node.include_base_grain:
parent_instance_set = parent_instance_set.transform(
FilterElements(
exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,))
)
)
parent_select_columns = tuple(
[
column
for column in parent_select_columns
if column.column_alias != parent_time_dimension_instance.associated_column.column_name
]
)

# Build output time spine instances and columns.
time_spine_instance = TimeDimensionInstance(
defined_from=parent_time_dimension_instance.defined_from,
Expand All @@ -1498,10 +1481,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
),
)
return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_instance_set, parent_data_set.instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description + "\n" + parent_data_set.checked_sql_select_node.description,
select_columns=parent_select_columns + time_spine_select_columns,
description=parent_data_set.checked_sql_select_node.description + "\n" + node.description,
select_columns=parent_data_set.checked_sql_select_node.select_columns + time_spine_select_columns,
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),
Expand Down

0 comments on commit 3050414

Please sign in to comment.