Skip to content

Commit

Permalink
Add SQL rendering logic for custom granularities
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 19, 2024
1 parent 2a1775f commit f636c63
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
aggregation_state=self.aggregation_state,
)

@property
def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity),
date_part=self.date_part,
aggregation_state=self.aggregation_state,
)

def with_grain_and_date_part( # noqa: D102
self, time_granularity: ExpandedTimeGranularity, date_part: Optional[DatePart]
) -> TimeDimensionSpec:
Expand Down
33 changes: 28 additions & 5 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,31 @@

@dataclass(frozen=True)
class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity."""
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity.
Args:
time_dimension_spec: The time dimension spec with a custom granularity that will be satisfied by this node.
include_base_grain: Bool that indicates if a spec with the custom granularity's base grain
should be included in the node's output dataset. This is needed when the same time dimension is requested
twice in one query, with both a custom grain and that custom grain's base grain.
"""

time_dimension_spec: TimeDimensionSpec
include_base_grain: bool

def __post_init__(self) -> None: # noqa: D105
assert (
self.time_dimension_spec.time_granularity.is_custom_granularity
), "Time granularity for time dimension spec in JoinToCustomGranularityNode must be qualified as custom granularity."
f" Instead, found {self.time_dimension_spec.time_granularity.name}. This indicates internal misconfiguration."

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_grain: bool
) -> JoinToCustomGranularityNode:
return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec)
return JoinToCustomGranularityNode(
parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec, include_base_grain=include_base_grain
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
Expand All @@ -39,19 +55,26 @@ def description(self) -> str: # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("time_dimension_spec", self.time_dimension_spec),
DisplayedProperty("include_base_grain", self.include_base_grain),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec
return (
isinstance(other_node, self.__class__)
and other_node.time_dimension_spec == self.time_dimension_spec
and other_node.include_base_grain == self.include_base_grain
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> JoinToCustomGranularityNode:
assert len(new_parent_nodes) == 1, "JoinToCustomGranularity accepts exactly one parent node."
return JoinToCustomGranularityNode.create(
parent_node=new_parent_nodes[0], time_dimension_spec=self.time_dimension_spec
parent_node=new_parent_nodes[0],
time_dimension_spec=self.time_dimension_spec,
include_base_grain=self.include_base_grain,
)
118 changes: 110 additions & 8 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def __init__(
self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
)
self._custom_granularity_time_spine_sources = TimeSpineSource.build_custom_time_spine_sources(
tuple(self._time_spine_sources.values())
)

@property
def column_association_resolver(self) -> ColumnAssociationResolver: # noqa: D102
Expand Down Expand Up @@ -262,9 +265,10 @@ def _make_time_spine_data_set(
column_alias = self.column_association_resolver.resolve_spec(agg_time_dimension_instance.spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
# TODO: also handle date part.
# TODO: [custom granularity] add support for custom granularities to make_time_spine_data_set
agg_time_grain = agg_time_dimension_instance.spec.time_granularity
assert not agg_time_grain.is_custom_granularity, "Custom time granularities are not yet supported!"
assert (
not agg_time_grain.is_custom_granularity
), "Custom time granularities are not yet supported for all queries."
if agg_time_grain.base_granularity == time_spine_source.base_granularity:
select_columns += (SqlSelectColumn(expr=column_expr, column_alias=column_alias),)
# If any columns have a different granularity, apply a DATE_TRUNC() and aggregate via group_by.
Expand Down Expand Up @@ -306,6 +310,7 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
instance_set=node.data_set.instance_set,
)

# TODO: write tests for custom granularities that hit this node
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
"""Generate time range join SQL."""
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()
Expand Down Expand Up @@ -1229,6 +1234,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
),
)

# TODO: write tests for custom granularities that hit this node
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()
Expand All @@ -1252,10 +1258,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
agg_time_dimension_instances.append(instance)

# Choose the instance with the smallest standard granularity available.
# TODO: [custom granularity] Update to account for custom granularity instances
assert all(
[not instance.spec.time_granularity.is_custom_granularity for instance in agg_time_dimension_instances]
), "Custom granularities are not yet supported!"
), "Custom granularities are not yet supported for all queries."
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
assert len(agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
Expand Down Expand Up @@ -1341,9 +1346,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec

# TODO: this will break when we start supporting smaller grain than DAY unless the time spine table is
# updated to use the smallest available grain.
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
Expand Down Expand Up @@ -1408,8 +1410,108 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
),
)

def _get_time_spine_for_custom_granularity(self, custom_granularity: str) -> TimeSpineSource:
time_spine_source = self._custom_granularity_time_spine_sources.get(custom_granularity)
assert time_spine_source, (
f"Custom granularity {custom_granularity} does not not exist in time spine sources. "
f"Available custom granularities: {list(self._custom_granularity_time_spine_sources.keys())}"
)
return time_spine_source

def _get_custom_granularity_column_name(self, custom_granularity_name: str) -> str:
time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name)
for custom_granularity in time_spine_source.custom_granularities:
print(custom_granularity)
if custom_granularity.name == custom_granularity_name:
return custom_granularity.column_name if custom_granularity.column_name else custom_granularity.name

raise RuntimeError(
f"Custom granularity {custom_granularity} not found. This indicates internal misconfiguration."
)

def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102
raise NotImplementedError # TODO in later commit
parent_data_set = node.parent_node.accept(self)

# New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node.
parent_alias = parent_data_set.checked_sql_select_node.from_source_alias
parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in parent_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec.with_base_grain:
parent_time_dimension_instance = instance
break
assert parent_time_dimension_instance, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
"This indicates internal misconfiguration."
)

# Build join expression.
time_spine_alias = self._next_unique_table_alias()
custom_granularity_name = node.time_dimension_spec.time_granularity.name
time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name)
left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name
)
if parent_time_dimension_instance.spec.time_granularity.base_granularity != time_spine_source.base_granularity:
# If needed, apply DATE_TRUNC to parent column match the time spine spine that's column being joined to.
left_expr_for_join = SqlDateTruncExpression.create(
time_granularity=time_spine_source.base_granularity, arg=left_expr_for_join
)
join_description = SqlJoinDescription(
right_source=SqlTableNode.create(sql_table=time_spine_source.spine_table),
right_source_alias=time_spine_alias,
on_condition=SqlComparisonExpression.create(
left_expr=left_expr_for_join,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_spine_source.base_column
),
),
join_type=SqlJoinType.LEFT_OUTER,
)

# Remove base grain from parent dataset, unless that grain was also requested (in addition to the custom grain).
parent_instance_set = parent_data_set.instance_set
parent_select_columns = parent_data_set.checked_sql_select_node.select_columns
if not node.include_base_grain:
parent_instance_set = parent_instance_set.transform(
FilterElements(
exclude_specs=InstanceSpecSet(time_dimension_specs=(parent_time_dimension_instance.spec,))
)
)
parent_select_columns = tuple(
[
column
for column in parent_select_columns
if column.column_alias != parent_time_dimension_instance.associated_column.column_name
]
)

# Build output time spine instances and columns.
time_spine_instance = TimeDimensionInstance(
defined_from=parent_time_dimension_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(node.time_dimension_spec),),
spec=node.time_dimension_spec,
)
time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,))
time_spine_select_columns = (
SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias,
column_name=self._get_custom_granularity_column_name(custom_granularity_name),
),
column_alias=time_spine_instance.associated_column.column_name,
),
)
return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description + "\n" + parent_data_set.checked_sql_select_node.description,
select_columns=parent_select_columns + time_spine_select_columns,
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),
),
)

def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
Expand Down

0 comments on commit f636c63

Please sign in to comment.