diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py index dc6457c73a..f11883d7b2 100644 --- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from enum import Enum +from functools import cache from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME @@ -190,8 +191,18 @@ def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) exclude_fields=exclude_fields, ) - @staticmethod + @classmethod + @cache + def _get_compatible_grain_and_date_part(cls) -> Sequence[Tuple[ExpandedTimeGranularity, DatePart]]: + items = [] + for date_part in DatePart: + for compatible_granularity in date_part.compatible_granularities: + items.append((ExpandedTimeGranularity.from_time_granularity(compatible_granularity), date_part)) + return items + + @classmethod def generate_possible_specs_for_time_dimension( + cls, time_dimension_reference: TimeDimensionReference, entity_links: Tuple[EntityReference, ...], custom_granularities: Dict[str, ExpandedTimeGranularity], @@ -210,16 +221,15 @@ def generate_possible_specs_for_time_dimension( date_part=None, ) ) - for date_part in DatePart: - for compatible_granularity in date_part.compatible_granularities: - time_dimension_specs.append( - TimeDimensionSpec( - element_name=time_dimension_reference.element_name, - entity_links=entity_links, - time_granularity=ExpandedTimeGranularity.from_time_granularity(compatible_granularity), - date_part=date_part, - ) + for grain, date_part in cls._get_compatible_grain_and_date_part(): + time_dimension_specs.append( + TimeDimensionSpec( + element_name=time_dimension_reference.element_name, + entity_links=entity_links, + time_granularity=grain, + date_part=date_part, ) + ) return time_dimension_specs @property diff --git a/metricflow-semantics/metricflow_semantics/time/granularity.py b/metricflow-semantics/metricflow_semantics/time/granularity.py index cf036b869c..ff442c47bb 100644 --- a/metricflow-semantics/metricflow_semantics/time/granularity.py +++ b/metricflow-semantics/metricflow_semantics/time/granularity.py @@ -1,6 +1,8 @@ from __future__ import annotations from dataclasses import dataclass +from functools import cache, cached_property +from typing import ClassVar, FrozenSet from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -32,20 +34,28 @@ def __post_init__(self) -> None: f"{self.base_granularity}." ) - @property + @cached_property def is_custom_granularity(self) -> bool: # noqa: D102 return self.base_granularity.value != self.name @classmethod + @cache def from_time_granularity(cls, granularity: TimeGranularity) -> ExpandedTimeGranularity: - """Factory method for creating an ExpandedTimeGranularity from a standard TimeGranularity enumeration value.""" + """Factory method for creating an ExpandedTimeGranularity from a standard TimeGranularity enumeration value. + + This should be safe to use with `@cache` since the number of `TimeGranularity` is small and limited. + """ return ExpandedTimeGranularity(name=granularity.value, base_granularity=granularity) - @staticmethod - def is_standard_granularity_name(time_granularity_name: str) -> bool: - """Helper for checking if a given time granularity name is part of the standard TimeGranularity enumeration.""" - for granularity in TimeGranularity: - if time_granularity_name == granularity.value: - return True + @classmethod + @cache + def _standard_time_granularity_names(cls) -> FrozenSet: + """This should be safe to use with `@cache` since the number of `TimeGranularity` is small and limited.""" + return frozenset( + granularity.value for granularity in TimeGranularity + ) - return False + @classmethod + def is_standard_granularity_name(cls, time_granularity_name: str) -> bool: + """Helper for checking if a given time granularity name is part of the standard TimeGranularity enumeration.""" + return time_granularity_name in cls._standard_time_granularity_names()