diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 76b7d39..bf0cfa2 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -188,12 +188,16 @@ impl> Sol &self, solvable_id: SolvableId, ) -> Result> { - let mut output = AddClauseOutput::default(); - let mut queue = vec![solvable_id]; - let mut seen = HashSet::new(); - seen.insert(solvable_id); + let output = RefCell::new(AddClauseOutput::default()); + let queue = RefCell::new(vec![solvable_id]); + let seen = RefCell::new(HashSet::new()); + seen.borrow_mut().insert(solvable_id); + + while let Some(solvable_id) = queue.borrow_mut().pop() { + let output = &output; + let queue = &queue; + let seen = &seen; - while let Some(solvable_id) = queue.pop() { let mutex = { let mut clauses = self.clauses_added_for_solvable.borrow_mut(); let mutex = clauses @@ -237,6 +241,7 @@ impl> Sol .alloc(ClauseState::exclude(solvable_id, *reason)); // Exclusions are negative assertions, tracked outside of the watcher system + let mut output = output.borrow_mut(); output.negative_assertions.push((solvable_id, clause_id)); // There might be a conflict now @@ -251,14 +256,10 @@ impl> Sol }; // Add clauses for the requirements - for version_set_id in requirements { + let add_requirements = requirements.into_iter().map(|version_set_id| async move { let dependency_name = self.pool.resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package( - &mut output.negative_assertions, - &mut output.clauses_to_watch, - dependency_name, - ) - .await?; + self.add_clauses_for_package(&output, dependency_name) + .await?; // Find all the solvables that match for the given version set let candidates = self @@ -268,6 +269,8 @@ impl> Sol // Queue requesting the dependencies of the candidates as well if they are cheaply // available from the dependency provider. + let mut queue = queue.borrow_mut(); + let mut seen = seen.borrow_mut(); for &candidate in candidates { if seen.insert(candidate) && self.cache.are_dependencies_available_for(candidate) @@ -292,6 +295,7 @@ impl> Sol unreachable!(); }; + let mut output = output.borrow_mut(); if clause.has_watches() { output.clauses_to_watch.push(clause_id); } @@ -306,17 +310,14 @@ impl> Sol // Add assertions for unit clauses (i.e. those with no matching candidates) output.negative_assertions.push((solvable_id, clause_id)); } - } - // Add clauses for the constraints - for version_set_id in constrains { + Ok::<_, Box>(()) + }); + + let add_constrains = constrains.into_iter().map(|version_set_id| async move { let dependency_name = self.pool.resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package( - &mut output.negative_assertions, - &mut output.clauses_to_watch, - dependency_name, - ) - .await?; + self.add_clauses_for_package(&output, dependency_name) + .await?; // Find all the solvables that match for the given version set let constrained_candidates = self @@ -325,6 +326,7 @@ impl> Sol .await?; // Add forbidden clauses for the candidates + let mut output = output.borrow_mut(); for forbidden_candidate in constrained_candidates.iter().copied().collect_vec() { let (clause, conflict) = ClauseState::constrains( solvable_id, @@ -340,12 +342,22 @@ impl> Sol output.conflicting_clauses.push(clause_id); } } + + Ok::<_, Box>(()) + }); + + let add_requirements = futures::future::join_all(add_requirements); + let add_constrains = futures::future::join_all(add_constrains); + let (results1, results2) = + futures::future::join(add_requirements, add_constrains).await; + for result in results1.into_iter().chain(results2) { + result?; } *clauses_added = true; } - Ok(output) + Ok(output.into_inner()) } /// Adds all clauses for a specific package name. @@ -366,8 +378,7 @@ impl> Sol /// will be returned as an `Err(...)`. async fn add_clauses_for_package( &self, - negative_assertions: &mut Vec<(SolvableId, ClauseId)>, - clauses_to_watch: &mut Vec, + output: &RefCell, package_name: NameId, ) -> Result<(), Box> { let mutex = { @@ -402,6 +413,8 @@ impl> Sol ); } + let mut output = output.borrow_mut(); + // Each candidate gets a clause to disallow other candidates. for (i, &candidate) in candidates.iter().enumerate() { for &other_candidate in &candidates[i + 1..] { @@ -411,7 +424,7 @@ impl> Sol .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); - clauses_to_watch.push(clause_id); + output.clauses_to_watch.push(clause_id); } } @@ -425,7 +438,7 @@ impl> Sol .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); - clauses_to_watch.push(clause_id); + output.clauses_to_watch.push(clause_id); } } } @@ -438,7 +451,7 @@ impl> Sol .alloc(ClauseState::exclude(solvable, reason)); // Exclusions are negative assertions, tracked outside of the watcher system - negative_assertions.push((solvable, clause_id)); + output.negative_assertions.push((solvable, clause_id)); // Conflicts should be impossible here debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true));