Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce recursive-call overhead in MetricTimeQueryValidationRule #1440

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def __init__(
max_entity_links=MAX_JOIN_HOPS,
)

# Cache for `get_min_queryable_time_granularity()`
self._metric_reference_to_min_metric_time_grain: Dict[MetricReference, TimeGranularity] = {}

# Cache for `get_valid_agg_time_dimensions_for_metric()`.
self._metric_reference_to_valid_agg_time_dimension_specs: Dict[
MetricReference, Sequence[TimeDimensionSpec]
] = {}

@functools.lru_cache
def linkable_elements_for_measure(
self,
Expand Down Expand Up @@ -183,6 +191,18 @@ def get_valid_agg_time_dimensions_for_metric(
self, metric_reference: MetricReference
) -> Sequence[TimeDimensionSpec]:
"""Get the agg time dimension specs that can be used in place of metric time for this metric, if applicable."""
result = self._metric_reference_to_valid_agg_time_dimension_specs.get(metric_reference)
if result is not None:
return result

result = self._get_valid_agg_time_dimensions_for_metric(metric_reference)
self._metric_reference_to_valid_agg_time_dimension_specs[metric_reference] = result

return result

def _get_valid_agg_time_dimensions_for_metric(
self, metric_reference: MetricReference
) -> Sequence[TimeDimensionSpec]:
agg_time_dimension_specs = self._get_agg_time_dimension_specs_for_metric(metric_reference)
distinct_agg_time_dimension_identifiers = set(
[(spec.reference, spec.entity_links) for spec in agg_time_dimension_specs]
Expand All @@ -204,6 +224,15 @@ def get_min_queryable_time_granularity(self, metric_reference: MetricReference)

Maps to the largest granularity defined for any of the metric's agg_time_dimensions.
"""
result = self._metric_reference_to_min_metric_time_grain.get(metric_reference)
if result is not None:
return result

result = self._get_min_queryable_time_granularity(metric_reference)
self._metric_reference_to_min_metric_time_grain[metric_reference] = result
return result

def _get_min_queryable_time_granularity(self, metric_reference: MetricReference) -> TimeGranularity:
agg_time_dimension_specs = self._get_agg_time_dimension_specs_for_metric(metric_reference)
assert (
agg_time_dimension_specs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
# Cache for defined time granularity.
self._time_dimension_to_defined_time_granularity: Dict[TimeDimensionReference, TimeGranularity] = {}

# Cache for agg. time dimension for measure.
self._measure_reference_to_agg_time_dimension_specs: Dict[MeasureReference, Sequence[TimeDimensionSpec]] = {}

def get_dimension_references(self) -> Sequence[DimensionReference]:
    """Return every dimension reference known to the semantic models.

    The references are the keys of the internal dimension index, returned as an
    immutable tuple in the index's iteration order.
    """
    indexed_references = self._dimension_index.keys()
    return tuple(indexed_references)
Expand Down Expand Up @@ -364,6 +367,17 @@ def get_agg_time_dimension_specs_for_measure(
self, measure_reference: MeasureReference
) -> Sequence[TimeDimensionSpec]:
"""Get the agg time dimension specs that can be used in place of metric time for this measure."""
result = self._measure_reference_to_agg_time_dimension_specs.get(measure_reference)
if result is not None:
return result

result = self._get_agg_time_dimension_specs_for_measure(measure_reference)
self._measure_reference_to_agg_time_dimension_specs[measure_reference] = result
return result

def _get_agg_time_dimension_specs_for_measure(
self, measure_reference: MeasureReference
) -> Sequence[TimeDimensionSpec]:
agg_time_dimension = self.get_agg_time_dimension_for_measure(measure_reference)
# A measure's agg_time_dimension is required to be in the same semantic model as the measure,
# so we can assume the same semantic model for both measure and dimension.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
ResolverInputForWhereFilterIntersection,
)
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule
from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
Expand Down Expand Up @@ -123,9 +125,7 @@ def __init__( # noqa: D107
where_filter_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._post_resolution_query_validator = PostResolutionQueryValidator(
manifest_lookup=self._manifest_lookup,
)
self._post_resolution_query_validator = PostResolutionQueryValidator()
self._where_filter_pattern_factory = where_filter_pattern_factory

@staticmethod
Expand Down Expand Up @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
query_level_issue_set = self._post_resolution_query_validator.validate_query(
resolution_dag=resolution_dag,
resolver_input_for_query=resolver_input_for_query,
validation_rules=(
MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query),
DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query),
),
)

if query_level_issue_set.has_issues:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
class PostResolutionQueryValidationRule(ABC):
"""A validation rule that runs after all query inputs have been resolved to specs."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
self._manifest_lookup = manifest_lookup
self._resolver_input_for_query = resolver_input_for_query

def _get_metric(self, metric_reference: MetricReference) -> Metric:
return self._manifest_lookup.metric_lookup.get_metric(metric_reference)
Expand All @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric:
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Given a metric that exists in a resolution DAG, check that the query is valid.
Expand All @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Validate the parameters to the query are valid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule

logger = logging.getLogger(__name__)
Expand All @@ -20,14 +18,10 @@
class DuplicateMetricValidationRule(PostResolutionQueryValidationRule):
"""Validates that a query does not include the same metric multiple times."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)

@override
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
duplicate_metric_references = tuple(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cached_property
from typing import Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
* Derived metrics with an offset time.
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
Expand All @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1
)
)

def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
for group_by_item_input in query_resolver_input.group_by_item_inputs:
@cached_property
def _group_by_items_include_metric_time(self) -> bool:
for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
return True

return False

def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool:
return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension(
query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference
)

def _group_by_items_include_agg_time_dimension(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> bool:
Expand All @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension(
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._get_metric(metric_reference)
query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time(
resolver_input_for_query
) or self._group_by_items_include_agg_time_dimension(
query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
)

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag(
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not query_includes_metric_time_or_agg_time_dimension
and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference)
):
return MetricFlowQueryResolutionIssueSet.from_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
Expand All @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag(
for input_metric in metric.input_metrics
)

if has_time_offset and not query_includes_metric_time_or_agg_time_dimension:
if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference):
return MetricFlowQueryResolutionIssueSet.from_issue(
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand All @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
Expand All @@ -26,27 +25,21 @@
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule


class PostResolutionQueryValidator:
"""Runs query validation rules after query resolution is complete."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
self._manifest_lookup = manifest_lookup
self._validation_rules = (
MetricTimeQueryValidationRule(self._manifest_lookup),
DuplicateMetricValidationRule(self._manifest_lookup),
)

def validate_query(
self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery
self,
resolution_dag: GroupByItemResolutionDag,
resolver_input_for_query: ResolverInputForQuery,
validation_rules: Sequence[PostResolutionQueryValidationRule],
) -> MetricFlowQueryResolutionIssueSet:
"""Validate according to the list of configured validation rules and return a set containing issues found."""
validation_visitor = _PostResolutionQueryValidationVisitor(
resolver_input_for_query=resolver_input_for_query,
validation_rules=self._validation_rules,
validation_rules=validation_rules,
)

return resolution_dag.sink_node.accept(validation_visitor)
Expand Down Expand Up @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow
issue_sets_to_merge.append(
validation_rule.validate_metric_in_resolution_dag(
metric_reference=node.metric_reference,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand All @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu
validation_rule.validate_query_in_resolution_dag(
metrics_in_query=node.metrics_in_query,
where_filter_intersection=node.where_filter_intersection,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand Down
Loading