Skip to content

Commit

Permalink
Reduce recursive-call overhead in MetricTimeQueryValidationRule (#1440)
Browse files Browse the repository at this point in the history

The check in `MetricTimeQueryValidationRule` is called for every metric
in a derived metric's ancestors, so this moves the expensive parts of the
check to only where they are needed and caches results when possible to
reduce runtimes. The signatures of the validation rule classes were
changed, so a number of the diff lines are related to that.
  • Loading branch information
plypaul authored Oct 2, 2024
1 parent 2ee4e79 commit df26a6b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
ResolverInputForWhereFilterIntersection,
)
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule
from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
Expand Down Expand Up @@ -123,9 +125,7 @@ def __init__( # noqa: D107
where_filter_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._post_resolution_query_validator = PostResolutionQueryValidator(
manifest_lookup=self._manifest_lookup,
)
self._post_resolution_query_validator = PostResolutionQueryValidator()
self._where_filter_pattern_factory = where_filter_pattern_factory

@staticmethod
Expand Down Expand Up @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
query_level_issue_set = self._post_resolution_query_validator.validate_query(
resolution_dag=resolution_dag,
resolver_input_for_query=resolver_input_for_query,
validation_rules=(
MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query),
DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query),
),
)

if query_level_issue_set.has_issues:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
class PostResolutionQueryValidationRule(ABC):
"""A validation rule that runs after all query inputs have been resolved to specs."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
self._manifest_lookup = manifest_lookup
self._resolver_input_for_query = resolver_input_for_query

def _get_metric(self, metric_reference: MetricReference) -> Metric:
return self._manifest_lookup.metric_lookup.get_metric(metric_reference)
Expand All @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric:
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Given a metric that exists in a resolution DAG, check that the query is valid.
Expand All @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Validate the parameters to the query are valid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule

logger = logging.getLogger(__name__)
Expand All @@ -20,14 +18,10 @@
class DuplicateMetricValidationRule(PostResolutionQueryValidationRule):
"""Validates that a query does not include the same metric multiple times."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)

@override
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
duplicate_metric_references = tuple(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cached_property
from typing import Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
* Derived metrics with an offset time.
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
Expand All @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1
)
)

def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
for group_by_item_input in query_resolver_input.group_by_item_inputs:
@cached_property
def _group_by_items_include_metric_time(self) -> bool:
for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
return True

return False

def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool:
return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension(
query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference
)

def _group_by_items_include_agg_time_dimension(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> bool:
Expand All @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension(
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._get_metric(metric_reference)
query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time(
resolver_input_for_query
) or self._group_by_items_include_agg_time_dimension(
query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
)

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag(
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not query_includes_metric_time_or_agg_time_dimension
and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference)
):
return MetricFlowQueryResolutionIssueSet.from_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
Expand All @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag(
for input_metric in metric.input_metrics
)

if has_time_offset and not query_includes_metric_time_or_agg_time_dimension:
if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference):
return MetricFlowQueryResolutionIssueSet.from_issue(
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand All @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
Expand All @@ -26,27 +25,21 @@
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule


class PostResolutionQueryValidator:
"""Runs query validation rules after query resolution is complete."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
self._manifest_lookup = manifest_lookup
self._validation_rules = (
MetricTimeQueryValidationRule(self._manifest_lookup),
DuplicateMetricValidationRule(self._manifest_lookup),
)

def validate_query(
self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery
self,
resolution_dag: GroupByItemResolutionDag,
resolver_input_for_query: ResolverInputForQuery,
validation_rules: Sequence[PostResolutionQueryValidationRule],
) -> MetricFlowQueryResolutionIssueSet:
"""Validate according to the list of configured validation rules and return a set containing issues found."""
validation_visitor = _PostResolutionQueryValidationVisitor(
resolver_input_for_query=resolver_input_for_query,
validation_rules=self._validation_rules,
validation_rules=validation_rules,
)

return resolution_dag.sink_node.accept(validation_visitor)
Expand Down Expand Up @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow
issue_sets_to_merge.append(
validation_rule.validate_metric_in_resolution_dag(
metric_reference=node.metric_reference,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand All @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu
validation_rule.validate_query_in_resolution_dag(
metrics_in_query=node.metrics_in_query,
where_filter_intersection=node.where_filter_intersection,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand Down

0 comments on commit df26a6b

Please sign in to comment.