From ee0f20bb97ee834d79f69506a959ba9a78f16683 Mon Sep 17 00:00:00 2001 From: lcnr Date: Mon, 13 May 2024 15:50:59 +0000 Subject: [PATCH] move global cache lookup into fn --- .../rustc_middle/src/traits/solve/cache.rs | 34 ++++---- .../src/solve/search_graph.rs | 86 ++++++++++--------- 2 files changed, 62 insertions(+), 58 deletions(-) diff --git a/compiler/rustc_middle/src/traits/solve/cache.rs b/compiler/rustc_middle/src/traits/solve/cache.rs index 03ce7cf98cf7e..2ff4ade21d087 100644 --- a/compiler/rustc_middle/src/traits/solve/cache.rs +++ b/compiler/rustc_middle/src/traits/solve/cache.rs @@ -14,11 +14,11 @@ pub struct EvaluationCache<'tcx> { map: Lock, CacheEntry<'tcx>>>, } -#[derive(PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] pub struct CacheData<'tcx> { pub result: QueryResult<'tcx>, pub proof_tree: Option<&'tcx [inspect::GoalEvaluationStep>]>, - pub reached_depth: usize, + pub additional_depth: usize, pub encountered_overflow: bool, } @@ -29,7 +29,7 @@ impl<'tcx> EvaluationCache<'tcx> { tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, proof_tree: Option<&'tcx [inspect::GoalEvaluationStep>]>, - reached_depth: usize, + additional_depth: usize, encountered_overflow: bool, cycle_participants: FxHashSet>, dep_node: DepNodeIndex, @@ -40,17 +40,17 @@ impl<'tcx> EvaluationCache<'tcx> { let data = WithDepNode::new(dep_node, QueryData { result, proof_tree }); entry.cycle_participants.extend(cycle_participants); if encountered_overflow { - entry.with_overflow.insert(reached_depth, data); + entry.with_overflow.insert(additional_depth, data); } else { - entry.success = Some(Success { data, reached_depth }); + entry.success = Some(Success { data, additional_depth }); } if cfg!(debug_assertions) { drop(map); - if Some(CacheData { result, proof_tree, reached_depth, encountered_overflow }) - != self.get(tcx, key, |_| false, Limit(reached_depth)) - { - bug!("unable to retrieve inserted element from cache: {key:?}"); + let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow }; + let actual = self.get(tcx, key, [], Limit(additional_depth)); + if !actual.as_ref().is_some_and(|actual| expected == *actual) { + bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}"); } } } @@ -63,23 +63,25 @@ impl<'tcx> EvaluationCache<'tcx> { &self, tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, - cycle_participant_in_stack: impl FnOnce(&FxHashSet>) -> bool, + stack_entries: impl IntoIterator>, available_depth: Limit, ) -> Option> { let map = self.map.borrow(); let entry = map.get(&key)?; - if cycle_participant_in_stack(&entry.cycle_participants) { - return None; + for stack_entry in stack_entries { + if entry.cycle_participants.contains(&stack_entry) { + return None; + } } if let Some(ref success) = entry.success { - if available_depth.value_within_limit(success.reached_depth) { + if available_depth.value_within_limit(success.additional_depth) { let QueryData { result, proof_tree } = success.data.get(tcx); return Some(CacheData { result, proof_tree, - reached_depth: success.reached_depth, + additional_depth: success.additional_depth, encountered_overflow: false, }); } @@ -90,7 +92,7 @@ impl<'tcx> EvaluationCache<'tcx> { CacheData { result, proof_tree, - reached_depth: available_depth.0, + additional_depth: available_depth.0, encountered_overflow: true, } }) @@ -99,7 +101,7 @@ impl<'tcx> EvaluationCache<'tcx> { struct Success<'tcx> { data: WithDepNode>, - reached_depth: usize, + additional_depth: usize, } #[derive(Clone, Copy)] diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs index 87e4fd9ae7368..6cc674dcfed1c 100644 --- a/compiler/rustc_trait_selection/src/solve/search_graph.rs +++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs @@ -134,16 +134,6 @@ impl SearchGraph { self.mode } - /// Update the stack and reached depths on cache hits. - #[instrument(level = "trace", skip(self))] - fn on_cache_hit(&mut self, additional_depth: usize, encountered_overflow: bool) { - let reached_depth = self.stack.next_index().plus(additional_depth); - if let Some(last) = self.stack.raw.last_mut() { - last.reached_depth = last.reached_depth.max(reached_depth); - last.encountered_overflow |= encountered_overflow; - } - } - /// Pops the highest goal from the stack, lazily updating the /// the next goal in the stack. /// @@ -276,37 +266,7 @@ impl<'tcx> SearchGraph> { return Self::response_no_constraints(tcx, input, Certainty::overflow(true)); }; - // Try to fetch the goal from the global cache. - 'global: { - let Some(CacheData { result, proof_tree, reached_depth, encountered_overflow }) = - self.global_cache(tcx).get( - tcx, - input, - |cycle_participants| { - self.stack.iter().any(|entry| cycle_participants.contains(&entry.input)) - }, - available_depth, - ) - else { - break 'global; - }; - - // If we're building a proof tree and the current cache entry does not - // contain a proof tree, we do not use the entry but instead recompute - // the goal. We simply overwrite the existing entry once we're done, - // caching the proof tree. - if !inspect.is_noop() { - if let Some(revisions) = proof_tree { - inspect.goal_evaluation_kind( - inspect::WipCanonicalGoalEvaluationKind::Interned { revisions }, - ); - } else { - break 'global; - } - } - - self.on_cache_hit(reached_depth, encountered_overflow); - debug!("global cache hit"); + if let Some(result) = self.lookup_global_cache(tcx, input, available_depth, inspect) { return result; } @@ -388,7 +348,10 @@ impl<'tcx> SearchGraph> { // This is for global caching, so we properly track query dependencies. // Everything that affects the `result` should be performed within this - // `with_anon_task` closure. + // `with_anon_task` closure. If computing this goal depends on something + // not tracked by the cache key and from outside of this anon task, it + // must not be added to the global cache. Notably, this is the case for + // trait solver cycles participants. let ((final_entry, result), dep_node) = tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || { for _ in 0..FIXPOINT_STEP_LIMIT { @@ -446,6 +409,45 @@ impl<'tcx> SearchGraph> { result } + + /// Try to fetch a previously computed result from the global cache, + /// making sure to only do so if it would match the result of reevaluating + /// this goal. + fn lookup_global_cache( + &mut self, + tcx: TyCtxt<'tcx>, + input: CanonicalInput<'tcx>, + available_depth: Limit, + inspect: &mut ProofTreeBuilder>, + ) -> Option> { + let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self + .global_cache(tcx) + .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?; + + // If we're building a proof tree and the current cache entry does not + // contain a proof tree, we do not use the entry but instead recompute + // the goal. We simply overwrite the existing entry once we're done, + // caching the proof tree. + if !inspect.is_noop() { + if let Some(revisions) = proof_tree { + let kind = inspect::WipCanonicalGoalEvaluationKind::Interned { revisions }; + inspect.goal_evaluation_kind(kind); + } else { + return None; + } + } + + // Update the reached depth of the current goal to make sure + // its state is the same regardless of whether we've used the + // global cache or not. + let reached_depth = self.stack.next_index().plus(additional_depth); + if let Some(last) = self.stack.raw.last_mut() { + last.reached_depth = last.reached_depth.max(reached_depth); + last.encountered_overflow |= encountered_overflow; + } + + Some(result) + } } enum StepResult<'tcx> {