diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py index b8da4e7ce6..10a5d0d9b2 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py @@ -82,11 +82,11 @@ def __init__( # Map measures / entities to semantic models that contain them. self._entity_to_semantic_model: Dict[str, List[SemanticModel]] = defaultdict(list) - self._measure_to_semantic_model: Dict[str, List[SemanticModel]] = defaultdict(list) - + self._semantic_model_reference_to_semantic_model: Dict[SemanticModelReference, SemanticModel] = {} for semantic_model in self._semantic_models: for entity in semantic_model.entities: self._entity_to_semantic_model[entity.reference.element_name].append(semantic_model) + self._semantic_model_reference_to_semantic_model[semantic_model.reference] = semantic_model self._metric_to_linkable_element_sets: Dict[str, List[LinkableElementSet]] = {} self._metric_references_to_metrics: Dict[MetricReference, Metric] = {} @@ -94,6 +94,9 @@ def __init__( set ) + # Cache for `_get_joined_elements()`. + self._semantic_model_reference_to_joined_elements: Dict[SemanticModelReference, LinkableElementSet] = {} + start_time = time.time() for metric in self._semantic_manifest.metrics: self._metric_references_to_metrics[MetricReference(metric.name)] = metric @@ -526,8 +529,20 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference } ) - def _get_joined_elements(self, measure_semantic_model: SemanticModel) -> LinkableElementSet: + def _get_joined_elements(self, measure_semantic_model_reference: SemanticModelReference) -> LinkableElementSet: """Get the elements that can be generated by joining other models to the given model.""" + result = self._semantic_model_reference_to_joined_elements.get(measure_semantic_model_reference) + if result is not None: + return result + + result = self._get_joined_elements_without_cache(measure_semantic_model_reference) + self._semantic_model_reference_to_joined_elements[measure_semantic_model_reference] = result + return result + + def _get_joined_elements_without_cache( + self, measure_semantic_model_reference: SemanticModelReference + ) -> LinkableElementSet: + measure_semantic_model = self._semantic_model_reference_to_semantic_model[measure_semantic_model_reference] # Create single-hop elements join_paths = [] for entity in measure_semantic_model.entities: @@ -596,7 +611,7 @@ def _get_linkable_element_set_for_measure( metrics_linked_to_semantic_model = LinkableElementSet() metric_time_elements = self._get_metric_time_elements(measure_reference) - joined_elements = self._get_joined_elements(measure_semantic_model) + joined_elements = self._get_joined_elements(measure_semantic_model.reference) return LinkableElementSet.merge_by_path_key( (