diff --git a/cpp/include/resolvo.h b/cpp/include/resolvo.h index 97d00f5..5343ac7 100644 --- a/cpp/include/resolvo.h +++ b/cpp/include/resolvo.h @@ -4,6 +4,7 @@ #include "resolvo_internal.h" namespace resolvo { +using cbindgen_private::ConditionalRequirement; using cbindgen_private::Problem; using cbindgen_private::Requirement; @@ -24,6 +25,23 @@ inline Requirement requirement_union(VersionSetUnionId id) { return cbindgen_private::resolvo_requirement_union(id); } +/** + * Specifies a conditional requirement (dependency) of a single version set. + * A solvable belonging to the version set satisfies the requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_single(VersionSetId id) { + return cbindgen_private::resolvo_conditional_requirement_single(id); +} + +/** + * Specifies a conditional requirement (dependency) of the union (logical OR) of multiple version + * sets. A solvable belonging to any of the version sets contained in the union satisfies the + * requirement if the condition is true. + */ +inline ConditionalRequirement conditional_requirement_union(VersionSetUnionId id) { + return cbindgen_private::resolvo_conditional_requirement_union(id); +} + /** * Called to solve a package problem. * diff --git a/cpp/src/lib.rs b/cpp/src/lib.rs index 781e365..a35b576 100644 --- a/cpp/src/lib.rs +++ b/cpp/src/lib.rs @@ -31,6 +31,95 @@ impl From for resolvo::SolvableId { } } +/// A wrapper around an optional version set id. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct FfiOptionVersionSetId { + pub is_some: bool, + pub value: VersionSetId, +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v.into(), + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value.into()) + } else { + None + } + } +} + +impl From> for FfiOptionVersionSetId { + fn from(opt: Option) -> Self { + match opt { + Some(v) => Self { + is_some: true, + value: v, + }, + None => Self { + is_some: false, + value: VersionSetId { id: 0 }, + }, + } + } +} + +impl From for Option { + fn from(ffi: FfiOptionVersionSetId) -> Self { + if ffi.is_some { + Some(ffi.value) + } else { + None + } + } +} + +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +/// First VersionSetId is the condition, second is the requirement. +/// cbindgen:derive-eq +/// cbindgen:derive-neq +#[repr(C)] +#[derive(Copy, Clone)] +pub struct ConditionalRequirement { + pub condition: FfiOptionVersionSetId, + pub requirement: Requirement, +} + +impl From for ConditionalRequirement { + fn from(value: resolvo::ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + } + } +} + +impl From for resolvo::ConditionalRequirement { + fn from(value: ConditionalRequirement) -> Self { + Self { + condition: value.condition.into(), + requirement: value.requirement.into(), + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. /// cbindgen:derive-eq /// cbindgen:derive-neq @@ -162,7 +251,7 @@ pub struct Dependencies { /// A pointer to the first element of a list of requirements. Requirements /// defines which packages should be installed alongside the depending /// package and the constraints applied to the package. - pub requirements: Vector, + pub requirements: Vector, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set @@ -475,7 +564,7 @@ impl<'d> resolvo::DependencyProvider for &'d DependencyProvider { #[repr(C)] pub struct Problem<'a> { - pub requirements: Slice<'a, Requirement>, + pub requirements: Slice<'a, ConditionalRequirement>, pub constraints: Slice<'a, VersionSetId>, pub soft_requirements: Slice<'a, SolvableId>, } @@ -525,6 +614,28 @@ pub extern "C" fn resolvo_solve( } } +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_single( + version_set_id: VersionSetId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Single(version_set_id), + } +} + +#[no_mangle] +#[allow(unused)] +pub extern "C" fn resolvo_conditional_requirement_union( + version_set_union_id: VersionSetUnionId, +) -> ConditionalRequirement { + ConditionalRequirement { + condition: Option::::None.into(), + requirement: Requirement::Union(version_set_union_id), + } +} + #[no_mangle] #[allow(unused)] pub extern "C" fn resolvo_requirement_single(version_set_id: VersionSetId) -> Requirement { diff --git a/cpp/tests/solve.cpp b/cpp/tests/solve.cpp index 1bb02b7..952e86e 100644 --- a/cpp/tests/solve.cpp +++ b/cpp/tests/solve.cpp @@ -48,16 +48,17 @@ struct PackageDatabase : public resolvo::DependencyProvider { /** * Allocates a new requirement for a single version set. */ - resolvo::Requirement alloc_requirement(std::string_view package, uint32_t version_start, - uint32_t version_end) { + resolvo::ConditionalRequirement alloc_requirement(std::string_view package, + uint32_t version_start, + uint32_t version_end) { auto id = alloc_version_set(package, version_start, version_end); - return resolvo::requirement_single(id); + return resolvo::conditional_requirement_single(id); } /** * Allocates a new requirement for a version set union. */ - resolvo::Requirement alloc_requirement_union( + resolvo::ConditionalRequirement alloc_requirement_union( std::initializer_list> version_sets) { std::vector version_set_union{version_sets.size()}; @@ -69,7 +70,7 @@ struct PackageDatabase : public resolvo::DependencyProvider { auto id = resolvo::VersionSetUnionId{static_cast(version_set_unions.size())}; version_set_unions.push_back(std::move(version_set_union)); - return resolvo::requirement_union(id); + return resolvo::conditional_requirement_union(id); } /** @@ -219,7 +220,8 @@ SCENARIO("Solve") { const auto d_1 = db.alloc_candidate("d", 1, {}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = {db.alloc_requirement("a", 1, 3)}; + resolvo::Vector requirements = { + db.alloc_requirement("a", 1, 3)}; resolvo::Vector constraints = { db.alloc_version_set("b", 1, 3), db.alloc_version_set("c", 1, 3), @@ -263,7 +265,7 @@ SCENARIO("Solve Union") { "f", 1, {{db.alloc_requirement("b", 1, 10)}, {db.alloc_version_set("a", 10, 20)}}); // Construct a problem to be solved by the solver - resolvo::Vector requirements = { + resolvo::Vector requirements = { db.alloc_requirement_union({{"c", 1, 10}, {"d", 1, 10}}), db.alloc_requirement("e", 1, 10), db.alloc_requirement("f", 1, 10), diff --git a/src/conflict.rs b/src/conflict.rs index 3d121b6..bc72678 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -11,15 +11,11 @@ use petgraph::{ Direction, }; -use crate::solver::variable_map::VariableOrigin; use crate::{ internal::{ arena::ArenaId, id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, - }, - runtime::AsyncRuntime, - solver::{clause::Clause, Solver}, - DependencyProvider, Interner, Requirement, + }, requirement::Condition, runtime::AsyncRuntime, solver::{clause::Clause, variable_map::VariableOrigin, Solver}, DependencyProvider, Interner, Requirement }; /// Represents the cause of the solver being unable to find a solution @@ -160,6 +156,59 @@ impl Conflict { ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)), ); } + Clause::Conditional( + package_id, + condition_variables, + requirement, + ) => { + let solvable = package_id + .as_solvable_or_root(&solver.variable_map) + .expect("only solvables can be excluded"); + let package_node = Self::add_node(&mut graph, &mut nodes, solvable); + + let requirement_candidates = solver + .async_runtime + .block_on(solver.cache.get_or_cache_sorted_candidates( + *requirement, + )) + .unwrap_or_else(|_| { + unreachable!( + "The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`" + ) + }); + + if requirement_candidates.is_empty() { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?}, which has no candidates" + ); + graph.add_edge( + package_node, + unresolved_node, + ConflictEdge::ConditionalRequires( + *requirement, + condition_variables.iter().map(|(_, condition)| *condition).collect(), + ), + ); + } else { + tracing::trace!( + "{package_id:?} conditionally requires {requirement:?} if {condition_variables:?}" + ); + + for &candidate_id in requirement_candidates { + let candidate_node = + Self::add_node(&mut graph, &mut nodes, candidate_id.into()); + + graph.add_edge( + package_node, + candidate_node, + ConflictEdge::ConditionalRequires( + *requirement, + condition_variables.iter().map(|(_, condition)| *condition).collect(), + ), + ); + } + } + } } } @@ -210,7 +259,7 @@ impl Conflict { } /// A node in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Eq, PartialEq, Debug)] pub(crate) enum ConflictNode { /// Node corresponding to a solvable Solvable(SolvableOrRootId), @@ -239,33 +288,41 @@ impl ConflictNode { } /// An edge in the graph representation of a [`Conflict`] -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Debug)] pub(crate) enum ConflictEdge { /// The target node is a candidate for the dependency specified by the /// [`Requirement`] Requires(Requirement), /// The target node is involved in a conflict, caused by `ConflictCause` Conflict(ConflictCause), + /// The target node is a candidate for a conditional dependency + ConditionalRequires(Requirement, Vec), } impl ConflictEdge { - fn try_requires(self) -> Option { + fn try_requires_or_conditional(self) -> Option<(Requirement, Vec)> { match self { - ConflictEdge::Requires(match_spec_id) => Some(match_spec_id), + ConflictEdge::Requires(match_spec_id) => Some((match_spec_id, vec![])), + ConflictEdge::ConditionalRequires(match_spec_id, conditions) => { + Some((match_spec_id, conditions)) + } ConflictEdge::Conflict(_) => None, } } - fn requires(self) -> Requirement { + fn requires_or_conditional(self) -> (Requirement, Vec) { match self { - ConflictEdge::Requires(match_spec_id) => match_spec_id, + ConflictEdge::Requires(match_spec_id) => (match_spec_id, vec![]), + ConflictEdge::ConditionalRequires(match_spec_id, conditions) => { + (match_spec_id, conditions) + } ConflictEdge::Conflict(_) => panic!("expected requires edge, found conflict"), } } } /// Conflict causes -#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Debug)] pub(crate) enum ConflictCause { /// The solvable is locked Locked(SolvableId), @@ -341,6 +398,11 @@ impl ConflictGraph { ConflictEdge::Requires(_) if target != ConflictNode::UnresolvedDependency => { "black" } + ConflictEdge::ConditionalRequires(_, _) + if target != ConflictNode::UnresolvedDependency => + { + "blue" // This indicates that the requirement has candidates, but the condition is not met + } _ => "red", }; @@ -348,6 +410,16 @@ impl ConflictGraph { ConflictEdge::Requires(requirement) => { requirement.display(interner).to_string() } + ConflictEdge::ConditionalRequires(requirement, conditions) => { + format!( + "if {} then {}", + conditions.iter() + .map(|c| interner.display_condition(*c).to_string()) + .collect::>() + .join(" and "), + requirement.display(interner) + ) + } ConflictEdge::Conflict(ConflictCause::Constrains(version_set_id)) => { interner.display_version_set(*version_set_id).to_string() } @@ -493,10 +565,15 @@ impl ConflictGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Requires(req) => ((req, vec![]), e.target()), + ConflictEdge::ConditionalRequires(req, conditions) => { + ((req, conditions.clone()), e.target()) + } ConflictEdge::Conflict(_) => unreachable!(), }) - .chunk_by(|(&version_set_id, _)| version_set_id); + .collect::>() + .into_iter() + .chunk_by(|((&version_set_id, condition), _)| (version_set_id, condition.clone())); for (_, mut deps) in &dependencies { if deps.all(|(_, target)| !installable.contains(&target)) { @@ -539,10 +616,15 @@ impl ConflictGraph { .graph .edges_directed(nx, Direction::Outgoing) .map(|e| match e.weight() { - ConflictEdge::Requires(version_set_id) => (version_set_id, e.target()), + ConflictEdge::Requires(version_set_id) => ((version_set_id, vec![]), e.target()), + ConflictEdge::ConditionalRequires(reqs, conditions) => { + ((reqs, conditions.clone()), e.target()) + } ConflictEdge::Conflict(_) => unreachable!(), }) - .chunk_by(|(&version_set_id, _)| version_set_id); + .collect::>() + .into_iter() + .chunk_by(|((&version_set_id, condition), _)| (version_set_id, condition.clone())); // Missing if at least one dependency is missing if dependencies @@ -629,42 +711,6 @@ impl Indenter { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_indenter_without_top_level_indent() { - let indenter = Indenter::new(false); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), ""); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - } - - #[test] - fn test_indenter_with_multiple_siblings() { - let indenter = Indenter::new(true); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " ├─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " │ ├─ "); - } -} - /// A struct implementing [`fmt::Display`] that generates a user-friendly /// representation of a conflict graph pub struct DisplayUnsat<'i, I: Interner> { @@ -697,11 +743,13 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { top_level_indent: bool, ) -> fmt::Result { pub enum DisplayOp { + ConditionalRequirement((Requirement, Vec), Vec), Requirement(Requirement, Vec), Candidate(NodeIndex), } let graph = &self.graph.graph; + println!("graph {:?}", graph); let installable_nodes = &self.installable_set; let mut reported: HashSet = HashSet::new(); @@ -709,21 +757,26 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { let indenter = Indenter::new(top_level_indent); let mut stack = top_level_edges .iter() - .filter(|e| e.weight().try_requires().is_some()) - .chunk_by(|e| e.weight().requires()) + .filter(|e| e.weight().clone().try_requires_or_conditional().is_some()) + .chunk_by(|e| e.weight().clone().requires_or_conditional()) .into_iter() - .map(|(version_set_id, group)| { + .map(|(version_set_id_with_condition, group)| { let edges: Vec<_> = group.map(|e| e.id()).collect(); - (version_set_id, edges) + (version_set_id_with_condition, edges) }) - .sorted_by_key(|(_version_set_id, edges)| { + .sorted_by_key(|(_version_set_id_with_condition, edges)| { edges .iter() .any(|&edge| installable_nodes.contains(&graph.edge_endpoints(edge).unwrap().1)) }) - .map(|(version_set_id, edges)| { + .map(|((version_set_id, condition), edges)| { ( - DisplayOp::Requirement(version_set_id, edges), + if !condition.is_empty() { + println!("conditional requirement"); + DisplayOp::ConditionalRequirement((version_set_id, condition), edges) + } else { + DisplayOp::Requirement(version_set_id, edges) + }, indenter.push_level(), ) }) @@ -957,7 +1010,7 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { writeln!(f, "{indent}{version} would require",)?; let mut requirements = graph .edges(candidate) - .chunk_by(|e| e.weight().requires()) + .chunk_by(|e| e.weight().clone().requires_or_conditional()) .into_iter() .map(|(version_set_id, group)| { let edges: Vec<_> = group.map(|e| e.id()).collect(); @@ -969,9 +1022,16 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { .contains(&graph.edge_endpoints(edge).unwrap().1) }) }) - .map(|(version_set_id, edges)| { + .map(|((version_set_id, condition), edges)| { ( - DisplayOp::Requirement(version_set_id, edges), + if !condition.is_empty() { + DisplayOp::ConditionalRequirement( + (version_set_id, condition), + edges, + ) + } else { + DisplayOp::Requirement(version_set_id, edges) + }, indenter.push_level(), ) }) @@ -984,6 +1044,132 @@ impl<'i, I: Interner> DisplayUnsat<'i, I> { stack.extend(requirements); } } + DisplayOp::ConditionalRequirement((requirement, condition), edges) => { + debug_assert!(!edges.is_empty()); + + let installable = edges.iter().any(|&e| { + let (_, target) = graph.edge_endpoints(e).unwrap(); + installable_nodes.contains(&target) + }); + + let req = requirement.display(self.interner).to_string(); + let condition = condition.iter().map(|c| self.interner.display_condition(*c).to_string()).collect::>().join(" and "); + + let target_nx = graph.edge_endpoints(edges[0]).unwrap().1; + let missing = + edges.len() == 1 && graph[target_nx] == ConflictNode::UnresolvedDependency; + if missing { + // No candidates for requirement + if top_level { + writeln!(f, "{indent} the condition {condition} is true but no candidates were found for {req}.")?; + } else { + writeln!(f, "{indent}{req}, for which no candidates were found.",)?; + } + } else if installable { + // Package can be installed (only mentioned for top-level requirements) + if top_level { + writeln!( + f, + "{indent}due to the condition {condition}, {req} can be installed with any of the following options:" + )?; + } else { + writeln!(f, "{indent}{req}, which can be installed with any of the following options:")?; + } + + let children: Vec<_> = edges + .iter() + .filter(|&&e| { + installable_nodes.contains(&graph.edge_endpoints(e).unwrap().1) + }) + .map(|&e| { + ( + DisplayOp::Candidate(graph.edge_endpoints(e).unwrap().1), + indenter.push_level(), + ) + }) + .collect(); + + // TODO: this is an utterly ugly hack that should be burnt to ashes + let mut deduplicated_children = Vec::new(); + let mut merged_and_seen = HashSet::new(); + for child in children { + let (DisplayOp::Candidate(child_node), _) = child else { + unreachable!() + }; + let solvable_id = graph[child_node].solvable_or_root(); + let Some(solvable_id) = solvable_id.solvable() else { + continue; + }; + + let merged = self.merged_candidates.get(&solvable_id); + + // Skip merged stuff that we have already seen + if merged_and_seen.contains(&solvable_id) { + continue; + } + + if let Some(merged) = merged { + merged_and_seen.extend(merged.ids.iter().copied()) + } + + deduplicated_children.push(child); + } + + if !deduplicated_children.is_empty() { + deduplicated_children[0].1.set_last(); + } + + stack.extend(deduplicated_children); + } else { + // Package cannot be installed (the conflicting requirement is further down + // the tree) + if top_level { + writeln!(f, "{indent}The condition {condition} is true but {req} cannot be installed because there are no viable options:")?; + } else { + writeln!(f, "{indent}{req}, which cannot be installed because there are no viable options:")?; + } + + let children: Vec<_> = edges + .iter() + .map(|&e| { + ( + DisplayOp::Candidate(graph.edge_endpoints(e).unwrap().1), + indenter.push_level(), + ) + }) + .collect(); + + // TODO: this is an utterly ugly hack that should be burnt to ashes + let mut deduplicated_children = Vec::new(); + let mut merged_and_seen = HashSet::new(); + for child in children { + let (DisplayOp::Candidate(child_node), _) = child else { + unreachable!() + }; + let Some(solvable_id) = graph[child_node].solvable() else { + continue; + }; + let merged = self.merged_candidates.get(&solvable_id); + + // Skip merged stuff that we have already seen + if merged_and_seen.contains(&solvable_id) { + continue; + } + + if let Some(merged) = merged { + merged_and_seen.extend(merged.ids.iter().copied()) + } + + deduplicated_children.push(child); + } + + if !deduplicated_children.is_empty() { + deduplicated_children[0].1.set_last(); + } + + stack.extend(deduplicated_children); + } + } } } @@ -1020,6 +1206,7 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { let conflict = match e.weight() { ConflictEdge::Requires(_) => continue, ConflictEdge::Conflict(conflict) => conflict, + ConflictEdge::ConditionalRequires(_, _) => continue, }; // The only possible conflict at the root level is a Locked conflict @@ -1052,3 +1239,39 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_indenter_without_top_level_indent() { + let indenter = Indenter::new(false); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), ""); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + } + + #[test] + fn test_indenter_with_multiple_siblings() { + let indenter = Indenter::new(true); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " ├─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " │ ├─ "); + } +} diff --git a/src/internal/id.rs b/src/internal/id.rs index 47fe226..e3b160a 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -46,6 +46,12 @@ impl ArenaId for StringId { #[cfg_attr(feature = "serde", serde(transparent))] pub struct VersionSetId(pub u32); +impl From<(VersionSetId, Option)> for VersionSetId { + fn from((id, _): (VersionSetId, Option)) -> Self { + id + } +} + impl ArenaId for VersionSetId { fn from_usize(x: usize) -> Self { Self(x as u32) diff --git a/src/lib.rs b/src/lib.rs index 575c678..7e5c78e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,8 @@ pub use internal::{ mapping::Mapping, }; use itertools::Itertools; -pub use requirement::Requirement; +use requirement::Condition; +pub use requirement::{ConditionalRequirement, Requirement}; pub use solver::{Problem, Solver, SolverCache, UnsolvableOrCancelled}; /// An object that is used by the solver to query certain properties of @@ -73,6 +74,9 @@ pub trait Interner { /// user-friendly way. fn display_name(&self, name: NameId) -> impl Display + '_; + /// Returns an object that can used to display a [`Condition`] where a condition is either a [`Extra(StringId)`] or a [`VersionSetId`] + fn display_condition(&self, condition: Condition) -> impl Display + '_; + /// Returns an object that can be used to display the given version set in a /// user-friendly way. /// @@ -206,7 +210,7 @@ pub struct KnownDependencies { feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty") )] - pub requirements: Vec, + pub requirements: Vec, /// Defines additional constraints on packages that may or may not be part /// of the solution. Different from `requirements`, packages in this set diff --git a/src/requirement.rs b/src/requirement.rs index 244ec48..1f1b5bc 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -1,7 +1,115 @@ -use crate::{Interner, VersionSetId, VersionSetUnionId}; +use crate::{Interner, StringId, VersionSetId, VersionSetUnionId}; use itertools::Itertools; use std::fmt::Display; +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum Condition { + /// A condition that must be met for the requirement to be active. + VersionSetId(VersionSetId), + /// An extra which if enabled, requires further dependencies to be met. + Extra(StringId), +} + +impl From for Condition { + fn from(value: VersionSetId) -> Self { + Condition::VersionSetId(value) + } +} + +impl From for Condition { + fn from(value: StringId) -> Self { + Condition::Extra(value) + } +} + +impl From for VersionSetId { + fn from(value: Condition) -> Self { + match value { + Condition::VersionSetId(id) => id, + Condition::Extra(_) => panic!("Cannot convert Extra to VersionSetId"), + } + } +} + +/// Specifies a conditional requirement, where the requirement is only active when the condition is met. +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ConditionalRequirement { + /// The conditions that must be met for the requirement to be active. + pub conditions: Vec, + /// The requirement that is only active when the condition is met. + pub requirement: Requirement, +} + +impl ConditionalRequirement { + /// Creates a new conditional requirement. + pub fn new(conditions: Vec, requirement: Requirement) -> Self { + Self { + conditions, + requirement, + } + } + /// Returns the version sets that satisfy the requirement. + pub fn requirement_version_sets<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator + 'i { + self.requirement.version_sets(interner) + } + + /// Returns the version sets that satisfy the requirement, along with the condition that must be met. + pub fn version_sets_with_condition<'i>( + &'i self, + interner: &'i impl Interner, + ) -> impl Iterator)> + 'i { + self.requirement + .version_sets(interner) + .map(move |vs| (vs, self.conditions.clone())) + } + + /// Returns the condition and requirement. + pub fn into_condition_and_requirement(self) -> (Vec, Requirement) { + (self.conditions, self.requirement) + } +} + +impl From for ConditionalRequirement { + fn from(value: Requirement) -> Self { + Self { + conditions: vec![], + requirement: value, + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetId) -> Self { + Self { + conditions: vec![], + requirement: value.into(), + } + } +} + +impl From for ConditionalRequirement { + fn from(value: VersionSetUnionId) -> Self { + Self { + conditions: vec![], + requirement: value.into(), + } + } +} + +impl From<(VersionSetId, Vec)> for ConditionalRequirement { + fn from((requirement, conditions): (VersionSetId, Vec)) -> Self { + Self { + conditions, + requirement: requirement.into(), + } + } +} + /// Specifies the dependency of a solvable on a set of version sets. #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] diff --git a/src/snapshot.rs b/src/snapshot.rs index 0b8b6d2..2ca6889 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -15,8 +15,7 @@ use ahash::HashSet; use futures::FutureExt; use crate::{ - internal::arena::ArenaId, Candidates, Dependencies, DependencyProvider, Interner, Mapping, - NameId, Requirement, SolvableId, SolverCache, StringId, VersionSetId, VersionSetUnionId, + internal::arena::ArenaId, requirement::Condition, Candidates, Dependencies, DependencyProvider, Interner, Mapping, NameId, Requirement, SolvableId, SolverCache, StringId, VersionSetId, VersionSetUnionId }; /// A single solvable in a [`DependencySnapshot`]. @@ -220,7 +219,24 @@ impl DependencySnapshot { } } - for &requirement in deps.requirements.iter() { + for req in deps.requirements.iter() { + let (conditions, requirement) = req.clone().into_condition_and_requirement(); + + for condition in conditions { + match condition { + Condition::Extra(string_id) => { + if seen.insert(Element::String(string_id)) { + queue.push_back(Element::String(string_id)); + } + } + Condition::VersionSetId(version_set_id) => { + if seen.insert(Element::VersionSet(version_set_id)) { + queue.push_back(Element::VersionSet(version_set_id)); + } + } + } + } + match requirement { Requirement::Single(version_set) => { if seen.insert(Element::VersionSet(version_set)) { @@ -429,6 +445,13 @@ impl<'s> Interner for SnapshotProvider<'s> { self.string(string_id) } + fn display_condition(&self, condition: Condition) -> impl Display + '_ { + match condition { + Condition::Extra(string_id) => format!("{}", self.display_string(string_id)), + Condition::VersionSetId(version_set_id) => format!("{} {}", self.display_name(self.version_set_name(version_set_id)), self.display_version_set(version_set_id)), + } + } + fn version_set_name(&self, version_set: VersionSetId) -> NameId { self.version_set(version_set).name } diff --git a/src/solver/clause.rs b/src/solver/clause.rs index f034130..5b4649e 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -6,12 +6,14 @@ use std::{ }; use elsa::FrozenMap; +use itertools::Itertools; use crate::{ internal::{ arena::{Arena, ArenaId}, id::{ClauseId, LearntClauseId, StringId, VersionSetId}, }, + requirement::Condition, solver::{ decision_map::DecisionMap, decision_tracker::DecisionTracker, variable_map::VariableMap, VariableId, @@ -46,7 +48,7 @@ use crate::{ /// limited set of clauses. There are thousands of clauses for a particular /// dependency resolution problem, and we try to keep the [`Clause`] enum small. /// A naive implementation would store a `Vec`. -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Debug)] pub(crate) enum Clause { /// An assertion that the root solvable must be installed /// @@ -77,6 +79,10 @@ pub(crate) enum Clause { /// /// In SAT terms: (¬A ∨ ¬B) Constrains(VariableId, VariableId, VersionSetId), + /// In SAT terms: (¬A ∨ (¬C1 v ~C2 v ~C3 v ... v ~Cn) ∨ B1 ∨ B2 ∨ ... ∨ B99), where A is the solvable, + /// C1 to Cn are the conditions, and B1 to B99 represent the possible candidates for + /// the provided [`Requirement`]. + Conditional(VariableId, Vec<(VariableId, Condition)>, Requirement), /// Forbids the package on the right-hand side /// /// Note that the package on the left-hand side is not part of the clause, @@ -230,6 +236,40 @@ impl Clause { ) } + fn conditional( + parent_id: VariableId, + requirement: Requirement, + condition_variables: Vec<(VariableId, Condition)>, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Self, Option<[Literal; 2]>, bool) { + assert_ne!(decision_tracker.assigned_value(parent_id), Some(false)); + let mut requirement_candidates = requirement_candidates.into_iter(); + + let requirement_literal = if condition_variables.iter().all(|condition_variable| { + decision_tracker.assigned_value(condition_variable.0) == Some(true) + }) { + // then all of the conditions are true, so we can require the requirement + requirement_candidates + .find(|&id| decision_tracker.assigned_value(id) != Some(false)) + .map(|id| id.positive()) + } else { + None + }; + + ( + Clause::Conditional(parent_id, condition_variables.clone(), requirement), + Some([ + parent_id.negative(), + requirement_literal.unwrap_or(condition_variables.first().unwrap().0.negative()), + ]), + requirement_literal.is_none() + && condition_variables.iter().all(|condition_variable| { + decision_tracker.assigned_value(condition_variable.0) == Some(true) + }), + ) + } + /// Tries to fold over all the literals in the clause. /// /// This function is useful to iterate, find, or filter the literals in a @@ -248,10 +288,10 @@ impl Clause { where F: FnMut(C, Literal) -> ControlFlow, { - match *self { + match self { Clause::InstallRoot => unreachable!(), Clause::Excluded(solvable, _) => visit(init, solvable.negative()), - Clause::Learnt(learnt_id) => learnt_clauses[learnt_id] + Clause::Learnt(learnt_id) => learnt_clauses[*learnt_id] .iter() .copied() .try_fold(init, visit), @@ -267,11 +307,22 @@ impl Clause { .into_iter() .try_fold(init, visit), Clause::ForbidMultipleInstances(s1, s2, _) => { - [s1.negative(), s2].into_iter().try_fold(init, visit) + [s1.negative(), *s2].into_iter().try_fold(init, visit) } Clause::Lock(_, s) => [s.negative(), VariableId::root().negative()] .into_iter() .try_fold(init, visit), + Clause::Conditional(package_id, condition_variables, requirement) => { + iter::once(package_id.negative()) + .chain(condition_variables.iter().map(|c| c.0.negative())) + .chain( + requirements_to_sorted_candidates[&requirement] + .iter() + .flatten() + .map(|&s| s.positive()), + ) + .try_fold(init, visit) + } } } @@ -306,7 +357,7 @@ impl Clause { interner: &'i I, ) -> ClauseDisplay<'i, I> { ClauseDisplay { - kind: *self, + kind: self.clone(), variable_map, interner, } @@ -419,6 +470,33 @@ impl WatchedLiterals { (Self::from_kind_and_initial_watches(watched_literals), kind) } + /// Shorthand method to construct a [Clause::Conditional] without requiring + /// complicated arguments. + /// + /// The returned boolean value is true when adding the clause resulted in a + /// conflict. + pub fn conditional( + package_id: VariableId, + requirement: Requirement, + condition_variables: Vec<(VariableId, Condition)>, + decision_tracker: &DecisionTracker, + requirement_candidates: impl IntoIterator, + ) -> (Option, bool, Clause) { + let (kind, watched_literals, conflict) = Clause::conditional( + package_id, + requirement, + condition_variables, + decision_tracker, + requirement_candidates, + ); + + ( + WatchedLiterals::from_kind_and_initial_watches(watched_literals), + conflict, + kind, + ) + } + fn from_kind_and_initial_watches(watched_literals: Option<[Literal; 2]>) -> Option { let watched_literals = watched_literals?; debug_assert!(watched_literals[0] != watched_literals[1]); @@ -558,7 +636,7 @@ pub(crate) struct ClauseDisplay<'i, I: Interner> { impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.kind { + match &self.kind { Clause::InstallRoot => write!(f, "InstallRoot"), Clause::Excluded(variable, reason) => { write!( @@ -566,7 +644,7 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { "Excluded({}({:?}), {})", variable.display(self.variable_map, self.interner), variable, - self.interner.display_string(reason) + self.interner.display_string(*reason) ) } Clause::Learnt(learnt_id) => write!(f, "Learnt({learnt_id:?})"), @@ -587,7 +665,7 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { v1, v2.display(self.variable_map, self.interner), v2, - self.interner.display_version_set(version_set_id) + self.interner.display_version_set(*version_set_id) ) } Clause::ForbidMultipleInstances(v1, v2, name) => { @@ -598,7 +676,7 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { v1, v2.variable().display(self.variable_map, self.interner), v2, - self.interner.display_name(name) + self.interner.display_name(*name) ) } Clause::Lock(locked, other) => { @@ -611,6 +689,19 @@ impl<'i, I: Interner> Display for ClauseDisplay<'i, I> { other, ) } + Clause::Conditional(package_id, condition_variables, requirement) => { + write!( + f, + "Conditional({}({:?}), {}, {})", + package_id.display(self.variable_map, self.interner), + package_id, + condition_variables + .iter() + .map(|v| v.0.display(self.variable_map, self.interner)) + .join(", "), + requirement.display(self.interner), + ) + } } } } @@ -671,17 +762,11 @@ mod test { clause.as_ref().unwrap().watched_literals[0].variable(), parent ); - assert_eq!( - clause.unwrap().watched_literals[1].variable(), - candidate1.into() - ); + assert_eq!(clause.unwrap().watched_literals[1].variable(), candidate1); // No conflict, still one candidate available decisions - .try_add_decision( - Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), - 1, - ) + .try_add_decision(Decision::new(candidate1, false, ClauseId::from_usize(0)), 1) .unwrap(); let (clause, conflict, _kind) = WatchedLiterals::requires( parent, @@ -696,13 +781,13 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate2.into() + candidate2 ); // Conflict, no candidates available decisions .try_add_decision( - Decision::new(candidate2.into(), false, ClauseId::install_root()), + Decision::new(candidate2, false, ClauseId::install_root()), 1, ) .unwrap(); @@ -719,7 +804,7 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate1.into() + candidate1 ); // Panic diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 8c0e026..7ff6ae8 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -9,7 +9,7 @@ use elsa::FrozenMap; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use indexmap::IndexMap; use itertools::Itertools; -use variable_map::VariableMap; +use variable_map::{SolvableOrStringId, VariableMap}; use watch_map::WatchMap; use crate::{ @@ -19,9 +19,11 @@ use crate::{ id::{ClauseId, LearntClauseId, NameId, SolvableId, SolvableOrRootId, VariableId}, mapping::Mapping, }, + requirement::{Condition, ConditionalRequirement}, runtime::{AsyncRuntime, NowOrNeverRuntime}, solver::binary_encoding::AtMostOnceTracker, - Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, VersionSetId, + Candidates, Dependencies, DependencyProvider, KnownDependencies, Requirement, StringId, + VersionSetId, }; mod binary_encoding; @@ -36,6 +38,12 @@ mod watch_map; #[derive(Default)] struct AddClauseOutput { new_requires_clauses: Vec<(VariableId, Requirement, ClauseId)>, + new_conditional_clauses: Vec<( + VariableId, + Vec<(VariableId, Condition)>, + Requirement, + ClauseId, + )>, conflicting_clauses: Vec, negative_assertions: Vec<(VariableId, ClauseId)>, clauses_to_watch: Vec, @@ -51,7 +59,7 @@ struct AddClauseOutput { /// This struct follows the builder pattern and can have its fields set by one /// of the available setter methods. pub struct Problem { - requirements: Vec, + requirements: Vec, constraints: Vec, soft_requirements: S, } @@ -80,7 +88,7 @@ impl> Problem { /// /// Returns the [`Problem`] for further mutation or to pass to /// [`Solver::solve`]. - pub fn requirements(self, requirements: Vec) -> Self { + pub fn requirements(self, requirements: Vec) -> Self { Self { requirements, ..self @@ -150,6 +158,11 @@ pub struct Solver { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, + conditional_clauses: IndexMap< + (VariableId, Vec<(VariableId, Condition)>), + Vec<(Requirement, ClauseId)>, + ahash::RandomState, + >, watches: WatchMap, /// A mapping from requirements to the variables that represent the @@ -172,7 +185,7 @@ pub struct Solver { decision_tracker: DecisionTracker, /// The [`Requirement`]s that must be installed as part of the solution. - root_requirements: Vec, + root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, @@ -200,6 +213,7 @@ impl Solver { clauses: Clauses::default(), variable_map: VariableMap::default(), requires_clauses: Default::default(), + conditional_clauses: Default::default(), requirement_to_sorted_candidates: FrozenMap::default(), watches: WatchMap::new(), negative_assertions: Default::default(), @@ -213,7 +227,6 @@ impl Solver { clauses_added_for_solvable: Default::default(), forbidden_clauses_added: Default::default(), name_activity: Default::default(), - activity_add: 1.0, activity_decay: 0.95, } @@ -280,6 +293,7 @@ impl Solver { clauses: self.clauses, variable_map: self.variable_map, requires_clauses: self.requires_clauses, + conditional_clauses: self.conditional_clauses, requirement_to_sorted_candidates: self.requirement_to_sorted_candidates, watches: self.watches, negative_assertions: self.negative_assertions, @@ -364,7 +378,9 @@ impl Solver { ); for additional in problem.soft_requirements { - let additional_var = self.variable_map.intern_solvable(additional); + let additional_var = self + .variable_map + .intern_solvable_or_string(additional.into()); if self .decision_tracker @@ -660,6 +676,16 @@ impl Solver { .or_default() .push((requirement, clause_id)); } + + for (solvable_id, condition_variables, requirement, clause_id) in + output.new_conditional_clauses + { + self.conditional_clauses + .entry((solvable_id, condition_variables)) + .or_default() + .push((requirement, clause_id)); + } + self.negative_assertions .append(&mut output.negative_assertions); @@ -695,7 +721,7 @@ impl Solver { fn resolve_dependencies(&mut self, mut level: u32) -> Result { loop { // Make a decision. If no decision could be made it means the problem is - // satisfyable. + // satisfiable. let Some((candidate, required_by, clause_id)) = self.decide() else { break; }; @@ -767,8 +793,36 @@ impl Solver { } let mut best_decision: Option = None; - for (&solvable_id, requirements) in self.requires_clauses.iter() { + + // Chain together the requires_clauses and conditional_clauses iterations + let requires_iter = self + .requires_clauses + .iter() + .map(|(&solvable_id, requirements)| { + ( + solvable_id, + None, + requirements + .iter() + .map(|(r, c)| (*r, *c)) + .collect::>(), + ) + }); + + let conditional_iter = + self.conditional_clauses + .iter() + .map(|((solvable_id, condition), clauses)| { + ( + *solvable_id, + Some(condition.clone()), + clauses.iter().map(|(r, c)| (*r, *c)).collect::>(), + ) + }); + + for (solvable_id, condition, requirements) in requires_iter.chain(conditional_iter) { let is_explicit_requirement = solvable_id == VariableId::root(); + if let Some(best_decision) = &best_decision { // If we already have an explicit requirement, there is no need to evaluate // non-explicit requirements. @@ -782,11 +836,29 @@ impl Solver { continue; } - for (deps, clause_id) in requirements.iter() { + // For conditional clauses, check that at least one conditional variable is true + if let Some(condition_variable) = condition { + // Check if any candidate that matches the condition's version set is installed + let condition_met = + condition_variable + .iter() + .all(|(condition_variable, _)| { + self.decision_tracker.assigned_value(*condition_variable) == Some(true) + }); + + // If the condition is not met, skip this requirement entirely + if !condition_met { + continue; + } + } + + for (requirement, clause_id) in requirements { let mut candidate = ControlFlow::Break(()); // Get the candidates for the individual version sets. - let version_set_candidates = &self.requirement_to_sorted_candidates[deps]; + let version_set_candidates = &self.requirement_to_sorted_candidates[&requirement]; + + let version_sets = requirement.version_sets(self.provider()); // Iterate over all version sets in the requirement and find the first version // set that we can act on, or if a single candidate (from any version set) makes @@ -795,10 +867,7 @@ impl Solver { // NOTE: We zip the version sets from the requirements and the variables that we // previously cached. This assumes that the order of the version sets is the // same in both collections. - for (version_set, candidates) in deps - .version_sets(self.provider()) - .zip(version_set_candidates) - { + for (version_set, candidates) in version_sets.zip(version_set_candidates) { // Find the first candidate that is not yet assigned a value or find the first // value that makes this clause true. candidate = candidates.iter().try_fold( @@ -875,7 +944,7 @@ impl Solver { candidate_count, package_activity, ))) => { - let decision = (candidate, solvable_id, *clause_id); + let decision = (candidate, solvable_id, clause_id); best_decision = Some(match &best_decision { None => PossibleDecision { is_explicit_requirement, @@ -1029,7 +1098,7 @@ impl Solver { if level == 1 { for decision in self.decision_tracker.stack() { let clause_id = decision.derived_from; - let clause = self.clauses.kinds[clause_id.to_usize()]; + let clause = &self.clauses.kinds[clause_id.to_usize()]; let level = self.decision_tracker.level(decision.variable); let action = if decision.value { "install" } else { "forbid" }; @@ -1220,12 +1289,12 @@ impl Solver { // Assertions derived from learnt rules for learn_clause_idx in 0..self.learnt_clause_ids.len() { let clause_id = self.learnt_clause_ids[learn_clause_idx]; - let clause = self.clauses.kinds[clause_id.to_usize()]; + let clause = &self.clauses.kinds[clause_id.to_usize()]; let Clause::Learnt(learnt_index) = clause else { unreachable!(); }; - let literals = &self.learnt_clauses[learnt_index]; + let literals = &self.learnt_clauses[*learnt_index]; if literals.len() > 1 { continue; } @@ -1519,7 +1588,7 @@ async fn add_clauses_for_solvables( RequirementCandidateVariables, ahash::RandomState, >, - root_requirements: &[Requirement], + root_requirements: &[ConditionalRequirement], root_constraints: &[VersionSetId], ) -> Result> { let mut output = AddClauseOutput::default(); @@ -1534,6 +1603,8 @@ async fn add_clauses_for_solvables( SortedCandidates { solvable_id: SolvableOrRootId, requirement: Requirement, + version_set_conditions: Vec<(SolvableId, VersionSetId)>, + string_conditions: Vec, candidates: Vec<&'i [SolvableId]>, }, NonMatchingCandidates { @@ -1611,11 +1682,11 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), + Some(solvable_id) => variable_map.intern_solvable_or_string(solvable_id.into()), None => variable_map.root(), }; - let (requirements, constrains) = match dependencies { + let (conditional_requirements, constrains) = match dependencies { Dependencies::Known(deps) => (deps.requirements, deps.constrains), Dependencies::Unknown(reason) => { // There is no information about the solvable's dependencies, so we add @@ -1637,17 +1708,27 @@ async fn add_clauses_for_solvables( } }; - for version_set_id in requirements + for (version_set_id, conditions) in conditional_requirements .iter() - .flat_map(|requirement| requirement.version_sets(cache.provider())) - .chain(constrains.iter().copied()) + .flat_map(|conditional_requirement| { + conditional_requirement.version_sets_with_condition(cache.provider()) + }) + .chain(constrains.iter().map(|&vs| (vs, Vec::new()))) { let dependency_name = cache.provider().version_set_name(version_set_id); if clauses_added_for_package.insert(dependency_name) { - tracing::trace!( - "┝━ Adding clauses for package '{}'", - cache.provider().display_name(dependency_name), - ); + if !conditions.is_empty() { + tracing::trace!( + "┝━ Adding conditional clauses for package '{}' with the conditions '{}'", + cache.provider().display_name(dependency_name), + conditions.iter().map(|c| cache.provider().display_condition(*c)).join(", "), + ); + } else { + tracing::trace!( + "┝━ Adding clauses for package '{}'", + cache.provider().display_name(dependency_name), + ); + } pending_futures.push( async move { @@ -1660,32 +1741,107 @@ async fn add_clauses_for_solvables( } .boxed_local(), ); + + for condition in conditions { + if let Condition::Extra(_) = condition { + continue; + } + let condition_name = + cache.provider().version_set_name(condition.into()); + if clauses_added_for_package.insert(condition_name) { + pending_futures.push( + async move { + let condition_candidates = + cache.get_or_cache_candidates(condition_name).await?; + Ok(TaskResult::Candidates { + name_id: condition_name, + package_candidates: condition_candidates, + }) + } + .boxed_local(), + ); + } + } } } - for requirement in requirements { + for conditional_requirement in conditional_requirements { // Find all the solvable that match for the given version set - pending_futures.push( - async move { - let candidates = futures::future::try_join_all( - requirement - .version_sets(cache.provider()) - .map(|version_set| { - cache.get_or_cache_sorted_candidates_for_version_set( - version_set, - ) - }), - ) - .await?; + let version_sets = + conditional_requirement.requirement_version_sets(cache.provider()); + let candidates = + futures::future::try_join_all(version_sets.map(|version_set| { + cache.get_or_cache_sorted_candidates_for_version_set(version_set) + })) + .await?; + + // Collect all non-Extra conditions and their candidates + let conditions: Vec<_> = + conditional_requirement.conditions.to_vec(); + let mut string_conditions = Vec::new(); + let mut version_set_conditions = Vec::new(); + let mut condition_candidates_futures = Vec::new(); + + // Process collected conditions + for condition in conditions { + match condition { + Condition::Extra(extra_id) => { + string_conditions.push(extra_id); + } + Condition::VersionSetId(version_set_id) => { + version_set_conditions.push(version_set_id); + condition_candidates_futures + .push(cache.get_or_cache_matching_candidates(version_set_id)); + } + } + } - Ok(TaskResult::SortedCandidates { - solvable_id, - requirement, - candidates, - }) + // Get all condition candidates in parallel + let condition_candidates = + futures::future::try_join_all(condition_candidates_futures).await?; + + // Create cartesian product of all condition candidates + let condition_combinations = condition_candidates + .iter() + .zip(version_set_conditions.iter()) + .map(|(cands, cond)| cands.iter().map(move |&c| (c, *cond))) + .multi_cartesian_product(); + + // Create a task for each combination + let condition_combinations: Vec<_> = condition_combinations.collect(); + if !condition_combinations.is_empty() { + for condition_combination in condition_combinations { + let candidates = candidates.clone(); + let string_conditions = string_conditions.clone(); + let requirement = conditional_requirement.requirement; + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement, + version_set_conditions: condition_combination, + string_conditions, + candidates, + }) + } + .boxed_local(), + ); } - .boxed_local(), - ); + } else { + // Add a task result for the condition + pending_futures.push( + async move { + Ok(TaskResult::SortedCandidates { + solvable_id, + requirement: conditional_requirement.requirement, + version_set_conditions: Vec::new(), + string_conditions: Vec::new(), + candidates, + }) + } + .boxed_local(), + ); + } } for version_set_id in constrains { @@ -1721,10 +1877,12 @@ async fn add_clauses_for_solvables( // If there is a locked solvable, forbid other solvables. if let Some(locked_solvable_id) = package_candidates.locked { - let locked_solvable_var = variable_map.intern_solvable(locked_solvable_id); + let locked_solvable_var = + variable_map.intern_solvable_or_string(locked_solvable_id.into()); for &other_candidate in candidates { if other_candidate != locked_solvable_id { - let other_candidate_var = variable_map.intern_solvable(other_candidate); + let other_candidate_var = + variable_map.intern_solvable_or_string(other_candidate.into()); let (watched_literals, kind) = WatchedLiterals::lock(locked_solvable_var, other_candidate_var); let clause_id = clauses.alloc(watched_literals, kind); @@ -1737,7 +1895,7 @@ async fn add_clauses_for_solvables( // Add a clause for solvables that are externally excluded. for (solvable, reason) in package_candidates.excluded.iter().copied() { - let solvable_var = variable_map.intern_solvable(solvable); + let solvable_var = variable_map.intern_solvable_or_string(solvable.into()); let (watched_literals, kind) = WatchedLiterals::exclude(solvable_var, reason); let clause_id = clauses.alloc(watched_literals, kind); @@ -1751,6 +1909,8 @@ async fn add_clauses_for_solvables( TaskResult::SortedCandidates { solvable_id, requirement, + version_set_conditions, + string_conditions, candidates, } => { tracing::trace!( @@ -1760,7 +1920,7 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), + Some(solvable_id) => variable_map.intern_solvable_or_string(solvable_id.into()), None => variable_map.root(), }; @@ -1772,7 +1932,7 @@ async fn add_clauses_for_solvables( .map(|&candidates| { candidates .iter() - .map(|&var| variable_map.intern_solvable(var)) + .map(|&var| variable_map.intern_solvable_or_string(var.into())) .collect() }) .collect(), @@ -1820,30 +1980,88 @@ async fn add_clauses_for_solvables( ); } - // Add the requirements clause - let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); - let (watched_literals, conflict, kind) = WatchedLiterals::requires( - variable, - requirement, - version_set_variables.iter().flatten().copied(), - decision_tracker, - ); - let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + if !version_set_conditions.is_empty() { + let mut condition_variables = Vec::new(); + for (condition, condition_version_set_id) in version_set_conditions { + let condition_variable = variable_map + .intern_solvable_or_string(SolvableOrStringId::Solvable(condition)); + condition_variables.push(( + condition_variable, + Condition::VersionSetId(condition_version_set_id), + )); + } - if has_watches { - output.clauses_to_watch.push(clause_id); - } + for string_condition in string_conditions { + let condition_variable = variable_map.intern_solvable_or_string( + SolvableOrStringId::String(string_condition), + ); + condition_variables + .push((condition_variable, Condition::Extra(string_condition))); + } + + if !condition_variables.is_empty() { + // Add a condition clause + let (watched_literals, conflict, kind) = WatchedLiterals::conditional( + variable, + requirement, + condition_variables.clone(), + decision_tracker, + version_set_variables.iter().flatten().copied(), + ); - output - .new_requires_clauses - .push((variable, requirement, clause_id)); + // Add the conditional clause + let no_candidates = + candidates.iter().all(|candidates| candidates.is_empty()); - if conflict { - output.conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - output.negative_assertions.push((variable, clause_id)); + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output.new_conditional_clauses.push(( + variable, + condition_variables, + requirement, + clause_id, + )); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } + } + } else { + let (watched_literals, conflict, kind) = WatchedLiterals::requires( + variable, + requirement, + version_set_variables.iter().flatten().copied(), + decision_tracker, + ); + + // Add the requirements clause + let no_candidates = candidates.iter().all(|candidates| candidates.is_empty()); + + let has_watches = watched_literals.is_some(); + let clause_id = clauses.alloc(watched_literals, kind); + + if has_watches { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((variable, requirement, clause_id)); + + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((variable, clause_id)); + } } } TaskResult::NonMatchingCandidates { @@ -1861,13 +2079,14 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), + Some(solvable_id) => variable_map.intern_solvable_or_string(solvable_id.into()), None => variable_map.root(), }; // Add forbidden clauses for the candidates for &forbidden_candidate in non_matching_candidates { - let forbidden_candidate_var = variable_map.intern_solvable(forbidden_candidate); + let forbidden_candidate_var = + variable_map.intern_solvable_or_string(forbidden_candidate.into()); let (state, conflict, kind) = WatchedLiterals::constrains( variable, forbidden_candidate_var, diff --git a/src/solver/variable_map.rs b/src/solver/variable_map.rs index 608ed68..150d390 100644 --- a/src/solver/variable_map.rs +++ b/src/solver/variable_map.rs @@ -7,9 +7,27 @@ use crate::{ arena::ArenaId, id::{SolvableOrRootId, VariableId}, }, - Interner, NameId, SolvableId, + Interner, NameId, SolvableId, StringId, }; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum SolvableOrStringId { + Solvable(SolvableId), + String(StringId), +} + +impl From for SolvableOrStringId { + fn from(solvable_id: SolvableId) -> Self { + Self::Solvable(solvable_id) + } +} + +impl From for SolvableOrStringId { + fn from(string_id: StringId) -> Self { + Self::String(string_id) + } +} + /// All variables in the solver are stored in a `VariableMap`. This map is used /// to keep track of the semantics of a variable, e.g. what a specific variable /// represents. @@ -21,7 +39,7 @@ pub struct VariableMap { next_id: usize, /// A map from solvable id to variable id. - solvable_to_variable: HashMap, + solvable_or_string_id_to_variable: HashMap, /// Records the origins of all variables. origins: HashMap, @@ -38,6 +56,9 @@ pub enum VariableOrigin { /// A variable that helps encode an at most one constraint. ForbidMultiple(NameId), + + /// The variable represents a specific string. + String(StringId), } impl Default for VariableMap { @@ -47,7 +68,7 @@ impl Default for VariableMap { Self { next_id: 1, // The first variable id is 1 because 0 is reserved for the root. - solvable_to_variable: HashMap::default(), + solvable_or_string_id_to_variable: HashMap::default(), origins, } } @@ -55,16 +76,30 @@ impl Default for VariableMap { impl VariableMap { /// Allocate a variable for a new variable or reuse an existing one. - pub fn intern_solvable(&mut self, solvable_id: SolvableId) -> VariableId { - match self.solvable_to_variable.entry(solvable_id) { + pub fn intern_solvable_or_string( + &mut self, + solvable_or_string_id: SolvableOrStringId, + ) -> VariableId { + match self + .solvable_or_string_id_to_variable + .entry(solvable_or_string_id) + { Entry::Occupied(entry) => *entry.get(), Entry::Vacant(entry) => { let id = self.next_id; self.next_id += 1; let variable_id = VariableId::from_usize(id); entry.insert(variable_id); - self.origins - .insert(variable_id, VariableOrigin::Solvable(solvable_id)); + match solvable_or_string_id { + SolvableOrStringId::Solvable(solvable_id) => { + self.origins + .insert(variable_id, VariableOrigin::Solvable(solvable_id)); + } + SolvableOrStringId::String(string_id) => { + self.origins + .insert(variable_id, VariableOrigin::String(string_id)); + } + } variable_id } } @@ -73,7 +108,7 @@ impl VariableMap { /// Allocate a variable for a solvable or the root. pub fn intern_solvable_or_root(&mut self, solvable_or_root_id: SolvableOrRootId) -> VariableId { match solvable_or_root_id.solvable() { - Some(solvable_id) => self.intern_solvable(solvable_id), + Some(solvable_id) => self.intern_solvable_or_string(solvable_id.into()), None => VariableId::root(), } } @@ -141,6 +176,9 @@ impl<'i, I: Interner> Display for VariableDisplay<'i, I> { VariableOrigin::ForbidMultiple(name) => { write!(f, "forbid-multiple({})", self.interner.display_name(name)) } + VariableOrigin::String(string_id) => { + write!(f, "{}", self.interner.display_string(string_id)) + } } } } diff --git a/src/utils/pool.rs b/src/utils/pool.rs index 2a3b6fe..e496858 100644 --- a/src/utils/pool.rs +++ b/src/utils/pool.rs @@ -42,8 +42,6 @@ pub struct Pool { /// Map from package names to the id of their interned counterpart pub(crate) string_to_ids: FrozenCopyMap, - - /// Interned match specs pub(crate) version_sets: Arena, /// Map from version set to the id of their interned counterpart @@ -123,6 +121,13 @@ impl Pool { &self.package_names[name_id] } + /// Returns the extra associated with the provided [`StringId`]. + /// + /// Panics if the extra is not found in the pool. + // pub fn resolve_extra(&self, package_id: NameId, extra_id: StringId) -> &str { + // &self.strings[self.extra_to_ids.get_copy(&(package_id, extra_id)).unwrap()] + // } + /// Returns the [`NameId`] associated with the specified name or `None` if /// the name has not previously been interned using /// [`Self::intern_package_name`]. diff --git a/tests/solver.rs b/tests/solver.rs index de15d8a..4fdb34f 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -18,13 +18,13 @@ use std::{ use ahash::HashMap; use indexmap::IndexMap; use insta::assert_snapshot; -use itertools::Itertools; +use itertools::{Itertools}; use resolvo::{ snapshot::{DependencySnapshot, SnapshotProvider}, utils::Pool, - Candidates, Dependencies, DependencyProvider, Interner, KnownDependencies, NameId, Problem, - Requirement, SolvableId, Solver, SolverCache, StringId, UnsolvableOrCancelled, VersionSetId, - VersionSetUnionId, + Candidates, ConditionalRequirement, Dependencies, DependencyProvider, Interner, + KnownDependencies, NameId, Problem, Requirement, SolvableId, Solver, SolverCache, StringId, + UnsolvableOrCancelled, VersionSetId, VersionSetUnionId, }; use tracing_test::traced_test; use version_ranges::Ranges; @@ -74,6 +74,14 @@ impl Pack { self } + fn with_extra(mut self, extra: impl Into, value: impl Into) -> Pack { + self.extra + .entry(extra.into()) + .or_default() + .push(value.into()); + self + } + fn offset(&self, version_offset: i32) -> Pack { let mut pack = *self; pack.version = pack.version.wrapping_add_signed(version_offset); @@ -113,19 +121,31 @@ impl FromStr for Pack { struct Spec { name: String, versions: Ranges, + /// c 1; if a 1 and b 1 (conditions are a, b) + conditions: Vec>, + /// a[b,c] 1; if d 1 and c 1 (extras are b, c) + extras: Vec, } impl Spec { - pub fn new(name: String, versions: Ranges) -> Self { - Self { name, versions } + pub fn new( + name: String, + versions: Ranges, + conditions: Vec>, + extras: Vec, + ) -> Self { + Self { + name, + versions, + conditions, + extras, + } } pub fn parse_union( spec: &str, ) -> impl Iterator::Err>> + '_ { - spec.split('|') - .map(str::trim) - .map(|dep| Spec::from_str(dep)) + spec.split('|').map(str::trim).map(Spec::from_str) } } @@ -133,11 +153,34 @@ impl FromStr for Spec { type Err = (); fn from_str(s: &str) -> Result { - let split = s.split(' ').collect::>(); - let name = split - .first() - .expect("spec does not have a name") - .to_string(); + let split = s.split_once("; if"); + + if split.is_none() { + let split = s.split(' ').collect::>(); + + // Extract feature name from brackets if present + let name_parts: Vec<_> = split[0].split('[').collect(); + let (name, extras) = if name_parts.len() > 1 { + // Has features in brackets + let extras = name_parts[1] + .trim_end_matches(']') + .split(',') + .map(|f| f.trim().to_string()) + .collect::>(); + (name_parts[0].to_string(), extras) + } else { + (name_parts[0].to_string(), vec![]) + }; + + let versions = version_range(split.get(1)); + return Ok(Spec::new(name, versions, None, extras)); + } + + let (spec, condition) = split.unwrap(); + + let condition = Spec::parse_union(condition).next().unwrap().unwrap(); + + let spec = Spec::from_str(spec).unwrap(); fn version_range(s: Option<&&str>) -> Ranges { if let Some(s) = s { @@ -156,9 +199,12 @@ impl FromStr for Spec { } } - let versions = version_range(split.get(1)); - - Ok(Spec::new(name, versions)) + Ok(Spec::new( + spec.name, + spec.versions, + Some(Box::new(condition)), + spec.extras, + )) } } @@ -187,6 +233,7 @@ struct BundleBoxProvider { struct BundleBoxPackageDependencies { dependencies: Vec>, constrains: Vec, + extras: HashMap>, } impl BundleBoxProvider { @@ -200,16 +247,28 @@ impl BundleBoxProvider { .expect("package missing") } - pub fn requirements>(&self, requirements: &[&str]) -> Vec { + pub fn requirements, Vec)>>( + &self, + requirements: &[&str], + ) -> Vec { requirements .iter() .map(|dep| Spec::from_str(dep).unwrap()) - .map(|spec| self.intern_version_set(&spec)) + .map(|spec| { + ( + self.intern_version_set(&spec), + spec.conditions + .iter() + .as_ref() + .map(|c| self.intern_version_set(c)), + spec.extras.iter().map(|e| e.to_string()), + ) + }) .map(From::from) .collect() } - pub fn parse_requirements(&self, requirements: &[&str]) -> Vec { + pub fn parse_requirements(&self, requirements: &[&str]) -> Vec { requirements .iter() .map(|deps| { @@ -236,14 +295,15 @@ impl BundleBoxProvider { .intern_version_set_union(specs.next().unwrap(), specs) } - pub fn from_packages(packages: &[(&str, u32, Vec<&str>)]) -> Self { + pub fn from_packages(packages: &[(&str, u32, Vec<&str>, &[(&str, &[&str])])]) -> Self { let mut result = Self::new(); - for (name, version, deps) in packages { - result.add_package(name, Pack::new(*version), deps, &[]); + for (name, version, deps, extras) in packages { + result.add_package(name, Pack::new(*version), deps, &[], extras); } result } + /// TODO: we should be able to set packages with extras as favored or excluded as well pub fn set_favored(&mut self, package_name: &str, version: u32) { self.favored .insert(package_name.to_owned(), Pack::new(version)); @@ -267,8 +327,9 @@ impl BundleBoxProvider { package_version: Pack, dependencies: &[&str], constrains: &[&str], + extras: &[(&str, &[&str])], ) { - self.pool.intern_package_name(package_name); + let package_id = self.pool.intern_package_name(package_name); let dependencies = dependencies .iter() @@ -276,6 +337,19 @@ impl BundleBoxProvider { .collect::, _>>() .unwrap(); + let extras = extras + .iter() + .map(|(key, values)| { + (self.pool.intern_string(key), { + values + .iter() + .map(|dep| Spec::parse_union(dep).collect()) + .collect::, _>>() + .unwrap() + }) + }) + .collect::>(); + let constrains = constrains .iter() .map(|dep| Spec::from_str(dep)) @@ -290,6 +364,7 @@ impl BundleBoxProvider { BundleBoxPackageDependencies { dependencies, constrains, + extras, }, ); } @@ -352,6 +427,17 @@ impl Interner for BundleBoxProvider { self.pool.resolve_package_name(name).clone() } + fn display_condition(&self, condition: Condition) -> impl Display + '_ { + match condition { + Condition::Extra(extra) => self.display_string(extra), + Condition::VersionSet(version_set) => format!( + "{} {}", + self.display_name(self.version_set_name(version_set)), + self.display_version_set(version_set) + ), + } + } + fn display_version_set(&self, version_set: VersionSetId) -> impl Display + '_ { self.pool.resolve_version_set(version_set).clone() } @@ -380,13 +466,22 @@ impl DependencyProvider for BundleBoxProvider { &self, candidates: &[SolvableId], version_set: VersionSetId, + extra: Option, inverse: bool, ) -> Vec { let range = self.pool.resolve_version_set(version_set); candidates .iter() .copied() - .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) == !inverse) + .map(|s| self.pool.resolve_solvable(s)) + .filter(|s| range.contains(&s.record) != inverse) + .filter(|s| { + if let Some(extra) = extra { + s.record.extra.contains(&self.pool.resolve_string(extra)) + } else { + true + } + }) .collect() } @@ -443,7 +538,11 @@ impl DependencyProvider for BundleBoxProvider { self.maybe_delay(Some(candidates)).await } - async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + async fn get_dependencies( + &self, + solvable: SolvableId, + extra: Option, + ) -> Dependencies { tracing::info!( "get dependencies for {}", self.pool @@ -485,6 +584,12 @@ impl DependencyProvider for BundleBoxProvider { .await; }; + let extra_deps = if let Some(extra) = extra { + deps.extras.get(&extra) + } else { + None + }; + let mut result = KnownDependencies { requirements: Vec::with_capacity(deps.dependencies.len()), constrains: Vec::with_capacity(deps.constrains.len()), @@ -502,18 +607,67 @@ impl DependencyProvider for BundleBoxProvider { .intern_version_set(first_name, first.versions.clone()); let requirement = if remaining_req_specs.len() == 0 { - first_version_set.into() + let mut conditions = vec![]; + if let Some(condition) = &first.condition { + conditions.push(self.intern_version_set(condition)); + } + ConditionalRequirement::new(conditions, first_version_set.into()) } else { - let other_version_sets = remaining_req_specs.map(|spec| { - self.pool.intern_version_set( + // Check if all specs have the same condition + let common_condition = first.condition.as_ref().map(|c| self.intern_version_set(c)); + + // Collect version sets for union + let mut version_sets = vec![first_version_set]; + for spec in remaining_req_specs { + // Verify condition matches + if spec.condition.as_ref().map(|c| self.intern_version_set(c)) + != common_condition + { + panic!("All specs in a union must have the same condition"); + } + + version_sets.push(self.pool.intern_version_set( self.pool.intern_package_name(&spec.name), spec.versions.clone(), - ) - }); + )); + } + + // Create union and wrap in conditional if needed + let union = self + .pool + .intern_version_set_union(version_sets[0], version_sets.into_iter().skip(1)); + + let mut conditions = vec![]; + if let Some(condition) = common_condition { + conditions.push(condition); + } + ConditionalRequirement::new(conditions, union.into()) + }; + + result.requirements.push(requirement); + } - self.pool - .intern_version_set_union(first_version_set, other_version_sets) - .into() + for req in extra_deps { + let mut remaining_req_specs = req.iter(); + + let first = remaining_req_specs + .next() + .expect("Dependency spec must have at least one constraint"); + + let first_name = self.pool.intern_package_name(&first.name); + let first_version_set = self + .pool + .intern_version_set(first_name, first.versions.clone()); + + let requirement = if remaining_req_specs.len() == 0 { + let mut conditions = vec![Condition::Extra(extra)]; + if let Some(condition) = &first.condition { + conditions.push(self.intern_version_set(condition)); + } + ConditionalRequirement::new(conditions, first_version_set.into()) + } else { + // TODO: Implement extra deps for union + todo!("extra deps for union not implemented") }; result.requirements.push(requirement); @@ -538,7 +692,7 @@ impl DependencyProvider for BundleBoxProvider { } /// Create a string from a [`Transaction`] -fn transaction_to_string(interner: &impl Interner, solvables: &Vec) -> String { +fn transaction_to_string(interner: &impl Interner, solvables: &[SolvableId]) -> String { use std::fmt::Write; let mut buf = String::new(); for solvable in solvables @@ -590,7 +744,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { let requirements = provider.parse_requirements(specs); let mut solver = Solver::new(provider).with_runtime(runtime); - let problem = Problem::new().requirements(requirements); + let problem = Problem::new().requirements(requirements.into_iter().map(|r| r.into()).collect()); match solver.solve(problem) { Ok(solvables) => transaction_to_string(solver.provider(), &solvables), Err(UnsolvableOrCancelled::Unsolvable(conflict)) => { @@ -703,12 +857,12 @@ fn test_resolve_with_concurrent_metadata_fetching() { #[test] fn test_resolve_with_conflict() { let provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec!["conflicting 1"]), - ("asdf", 3, vec!["conflicting 0"]), - ("efgh", 7, vec!["conflicting 0"]), - ("efgh", 6, vec!["conflicting 0"]), - ("conflicting", 1, vec![]), - ("conflicting", 0, vec![]), + ("asdf", 4, vec!["conflicting 1"], &[]), + ("asdf", 3, vec!["conflicting 0"], &[]), + ("efgh", 7, vec!["conflicting 0"], &[]), + ("efgh", 6, vec!["conflicting 0"], &[]), + ("conflicting", 1, vec![], &[]), + ("conflicting", 0, vec![], &[]), ]); let result = solve_snapshot(provider, &["asdf", "efgh"]); insta::assert_snapshot!(result); @@ -719,9 +873,9 @@ fn test_resolve_with_conflict() { #[traced_test] fn test_resolve_with_nonexisting() { let provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec!["b"]), - ("asdf", 3, vec![]), - ("b", 1, vec!["idontexist"]), + ("asdf", 4, vec!["b"], &[]), + ("asdf", 3, vec![], &[]), + ("b", 1, vec!["idontexist"], &[]), ]); let requirements = provider.requirements(&["asdf"]); let mut solver = Solver::new(provider); @@ -745,18 +899,25 @@ fn test_resolve_with_nested_deps() { "apache-airflow", 3, vec!["opentelemetry-api 2..4", "opentelemetry-exporter-otlp"], + &[], ), ( "apache-airflow", 2, vec!["opentelemetry-api 2..4", "opentelemetry-exporter-otlp"], + &[], + ), + ("apache-airflow", 1, vec![], &[]), + ("opentelemetry-api", 3, vec!["opentelemetry-sdk"], &[]), + ("opentelemetry-api", 2, vec![], &[]), + ("opentelemetry-api", 1, vec![], &[]), + ( + "opentelemetry-exporter-otlp", + 1, + vec!["opentelemetry-grpc"], + &[], ), - ("apache-airflow", 1, vec![]), - ("opentelemetry-api", 3, vec!["opentelemetry-sdk"]), - ("opentelemetry-api", 2, vec![]), - ("opentelemetry-api", 1, vec![]), - ("opentelemetry-exporter-otlp", 1, vec!["opentelemetry-grpc"]), - ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"]), + ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"], &[]), ]); let requirements = provider.requirements(&["apache-airflow"]); let mut solver = Solver::new(provider); @@ -781,8 +942,9 @@ fn test_resolve_with_unknown_deps() { Pack::new(3).with_unknown_deps(), &[], &[], + &[], ); - provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); + provider.add_package("opentelemetry-api", Pack::new(2), &[], &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); let mut solver = Solver::new(provider); let problem = Problem::new().requirements(requirements); @@ -809,12 +971,14 @@ fn test_resolve_and_cancel() { Pack::new(3).with_unknown_deps(), &[], &[], + &[], ); provider.add_package( "opentelemetry-api", Pack::new(2).cancel_during_get_dependencies(), &[], &[], + &[], ); let error = solve_unsat(provider, &["opentelemetry-api"]); insta::assert_snapshot!(error); @@ -825,7 +989,7 @@ fn test_resolve_and_cancel() { #[test] fn test_resolve_locked_top_level() { let mut provider = - BundleBoxProvider::from_packages(&[("asdf", 4, vec![]), ("asdf", 3, vec![])]); + BundleBoxProvider::from_packages(&[("asdf", 4, vec![], &[]), ("asdf", 3, vec![], &[])]); provider.set_locked("asdf", 3); let requirements = provider.requirements(&["asdf"]); @@ -845,9 +1009,9 @@ fn test_resolve_locked_top_level() { #[test] fn test_resolve_ignored_locked_top_level() { let mut provider = BundleBoxProvider::from_packages(&[ - ("asdf", 4, vec![]), - ("asdf", 3, vec!["fgh"]), - ("fgh", 1, vec![]), + ("asdf", 4, vec![], &[]), + ("asdf", 3, vec!["fgh"], &[]), + ("fgh", 1, vec![], &[]), ]); provider.set_locked("fgh", 1); @@ -869,10 +1033,10 @@ fn test_resolve_ignored_locked_top_level() { #[test] fn test_resolve_favor_without_conflict() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec![]), - ("a", 2, vec![]), - ("b", 1, vec![]), - ("b", 2, vec![]), + ("a", 1, vec![], &[]), + ("a", 2, vec![], &[]), + ("b", 1, vec![], &[]), + ("b", 2, vec![], &[]), ]); provider.set_favored("a", 1); provider.set_favored("b", 1); @@ -888,12 +1052,12 @@ fn test_resolve_favor_without_conflict() { #[test] fn test_resolve_favor_with_conflict() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec!["c 1"]), - ("a", 2, vec![]), - ("b", 1, vec!["c 1"]), - ("b", 2, vec!["c 2"]), - ("c", 1, vec![]), - ("c", 2, vec![]), + ("a", 1, vec!["c 1"], &[]), + ("a", 2, vec![], &[]), + ("b", 1, vec!["c 1"], &[]), + ("b", 2, vec!["c 2"], &[]), + ("c", 1, vec![], &[]), + ("c", 2, vec![], &[]), ]); provider.set_favored("a", 1); provider.set_favored("b", 1); @@ -909,8 +1073,10 @@ fn test_resolve_favor_with_conflict() { #[test] fn test_resolve_cyclic() { - let provider = - BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); + let provider = BundleBoxProvider::from_packages(&[ + ("a", 2, vec!["b 0..10"], &[]), + ("b", 5, vec!["a 2..4"], &[]), + ]); let requirements = provider.requirements(&["a 0..100"]); let mut solver = Solver::new(provider); let problem = Problem::new().requirements(requirements); @@ -926,15 +1092,15 @@ fn test_resolve_cyclic() { #[test] fn test_resolve_union_requirements() { let mut provider = BundleBoxProvider::from_packages(&[ - ("a", 1, vec![]), - ("b", 1, vec![]), - ("c", 1, vec!["a"]), - ("d", 1, vec!["b"]), - ("e", 1, vec!["a | b"]), + ("a", 1, vec![], &[]), + ("b", 1, vec![], &[]), + ("c", 1, vec!["a"], &[]), + ("d", 1, vec!["b"], &[]), + ("e", 1, vec!["a | b"], &[]), ]); // Make d conflict with a=1 - provider.add_package("f", 1.into(), &["b"], &["a 2"]); + provider.add_package("f", 1.into(), &["b"], &["a 2"], &["b"]); let result = solve_snapshot(provider, &["c | d", "e", "f"]); assert_snapshot!(result, @r###" @@ -1079,8 +1245,8 @@ fn test_unsat_constrains() { ("b", 42, vec![]), ]); - provider.add_package("c", 10.into(), &[], &["b 0..50"]); - provider.add_package("c", 8.into(), &[], &["b 0..50"]); + provider.add_package("c", 10.into(), &[], &["b 0..50"], &[]); + provider.add_package("c", 8.into(), &[], &["b 0..50"], &[]); let error = solve_unsat(provider, &["a", "c"]); insta::assert_snapshot!(error); } @@ -1095,8 +1261,8 @@ fn test_unsat_constrains_2() { ("b", 2, vec!["c 2"]), ]); - provider.add_package("c", 1.into(), &[], &["a 3"]); - provider.add_package("c", 2.into(), &[], &["a 3"]); + provider.add_package("c", 1.into(), &[], &["a 3"], &[]); + provider.add_package("c", 2.into(), &[], &["a 3"], &[]); let error = solve_unsat(provider, &["a"]); insta::assert_snapshot!(error); } @@ -1270,13 +1436,13 @@ fn test_solve_with_additional_with_constrains() { ("e", 1, vec!["c"]), ]); - provider.add_package("f", 1.into(), &[], &["c 2..3"]); - provider.add_package("g", 1.into(), &[], &["b 2..3"]); - provider.add_package("h", 1.into(), &[], &["b 1..2"]); - provider.add_package("i", 1.into(), &[], &[]); - provider.add_package("j", 1.into(), &["i"], &[]); - provider.add_package("k", 1.into(), &["i"], &[]); - provider.add_package("l", 1.into(), &["j", "k"], &[]); + provider.add_package("f", 1.into(), &[], &["c 2..3"], &[]); + provider.add_package("g", 1.into(), &[], &["b 2..3"], &[]); + provider.add_package("h", 1.into(), &[], &["b 1..2"], &[]); + provider.add_package("i", 1.into(), &[], &[], &[]); + provider.add_package("j", 1.into(), &["i"], &[], &[]); + provider.add_package("k", 1.into(), &["i"], &[], &[]); + provider.add_package("l", 1.into(), &["j", "k"], &[], &[]); let requirements = provider.requirements(&["a 0..10", "e"]); let constraints = provider.requirements(&["b 1..2", "c", "k 2..3"]); @@ -1429,6 +1595,364 @@ fn test_explicit_root_requirements() { "###); } +#[test] +#[traced_test] +fn test_conditional_requirements() { + let mut provider = BundleBoxProvider::new(); + + // Add packages + provider.add_package("a", 1.into(), &["b"], &[], &[]); // a depends on b + provider.add_package("b", 1.into(), &[], &[], &[]); // Simple package b + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + + // Create problem with both regular and conditional requirements + let requirements = provider.requirements(&["a", "c 1; if b 1..2"]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + "###); +} + +#[test] +#[traced_test] +fn test_conditional_requirements_not_met() { + let mut provider = BundleBoxProvider::new(); + provider.add_package("b", 1.into(), &[], &[], &[]); // Add b=1 as a candidate + provider.add_package("b", 2.into(), &[], &[], &[]); // Different version of b + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 2"], &[], &[]); // a depends on b=2 specifically + + // Create conditional requirement: if b=1 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1", // If b=1 is installed, require c (note the exact version) + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (not b=1), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} + +#[test] +fn test_nested_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); // Base package + provider.add_package("b", 1.into(), &[], &[], &[]); // First level conditional + provider.add_package("c", 1.into(), &[], &[], &[]); // Second level conditional + provider.add_package("d", 1.into(), &[], &[], &[]); // Third level conditional + + // Create nested conditional requirements using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "c 1; if b 1", // If b is installed, require c + "d 1; if c 1", // If c is installed, require d + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // All packages should be installed due to chain reaction + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + d=1 + "###); +} + +#[test] +fn test_multiple_conditions_same_package() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("target", 1.into(), &[], &[], &[]); + + // Create multiple conditions that all require the same package using the parser + let requirements = provider.requirements(&[ + "b", // Only require package b + "target 1; if a 1", // If a is installed, require target + "target 1; if b 1", // If b is installed, require target + "target 1; if c 1", // If c is installed, require target + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // b and target should be installed + insta::assert_snapshot!(result, @r###" + b=1 + target=1 + "###); +} + +#[test] +fn test_circular_conditional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Setup packages + provider.add_package("a", 1.into(), &[], &[], &[]); + provider.add_package("b", 1.into(), &[], &[], &[]); + + // Create circular conditional dependencies using the parser + let requirements = provider.requirements(&[ + "a", // Require package a + "b 1; if a 1", // If a is installed, require b + "a 1; if b 1", // If b is installed, require a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Both packages should be installed due to circular dependency + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("a", 1.into(), &["b 4..6"], &[], &[]); // a depends on b versions 4-5 + + // Create conditional requirement: if b=1..3 is installed, require c + let requirements = provider.requirements(&[ + "a", // Require package a + "c 1; if b 1..3", // If b is version 1-2, require c + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=4 is installed (not b 1..3), c should not be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=5 + "###); +} + +#[test] +fn test_conditional_requirements_multiple_versions_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..3"], &[], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + c=2 + "###); +} + +#[test] +fn test_conditional_requirements_conflict() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + + // Package c has two versions with different dependencies + provider.add_package("c", 1.into(), &["d 1"], &[], &[]); // c v1 requires d v1 + provider.add_package("c", 2.into(), &["d 2"], &[], &[]); // c v2 requires d v2 + + // Package d has incompatible versions + provider.add_package("d", 1.into(), &[], &[], &[]); + provider.add_package("d", 2.into(), &[], &[], &[]); + + provider.add_package( + "a", + 1.into(), + &["b 1", "c 1; if b 1", "d 2", "c 2; if b 2"], + &[], + &[], + ); + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + + // This should fail to solve because: + // 1. When b=1 is chosen, it triggers the conditional requirement for c 1 + // 2. c 1 requires d 1, but a requires d 2 + // 3. d 1 and d 2 cannot be installed together + + let solved = solver + .solve(problem) + .map_err(|e| match e { + UnsolvableOrCancelled::Unsolvable(conflict) => { + conflict.display_user_friendly(&solver).to_string() + } + UnsolvableOrCancelled::Cancelled(_) => "kir".to_string(), + }) + .unwrap_err(); + + assert_snapshot!(solved, @r" + The following packages are incompatible + └─ a * cannot be installed because there are no viable options: + └─ a 1 would require + ├─ b >=1, <2, which can be installed with any of the following options: + │ └─ b 1 + ├─ d >=2, <3, which can be installed with any of the following options: + │ └─ d 2 + └─ c >=1, <2, which cannot be installed because there are no viable options: + └─ c 1 would require + └─ d >=1, <2, which cannot be installed because there are no viable options: + └─ d 1, which conflicts with the versions reported above. + "); +} + +/// In this test, the resolver installs the highest available version of b which is b 2 +/// However, the conditional requirement is that if b 1..2 is installed, require c +/// Since b 2 is installed, c should not be installed +#[test] +fn test_conditional_requirements_multiple_versions_not_met() { + let mut provider = BundleBoxProvider::new(); + + // Add multiple versions of package b + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("b", 3.into(), &[], &[], &[]); + provider.add_package("b", 4.into(), &[], &[], &[]); + provider.add_package("b", 5.into(), &[], &[], &[]); + + provider.add_package("c", 1.into(), &[], &[], &[]); // Simple package c + provider.add_package("c", 2.into(), &[], &[], &[]); // Version 2 of c + provider.add_package("c", 3.into(), &[], &[], &[]); // Version 3 of c + provider.add_package("a", 1.into(), &["b 1..3", "c 1..3; if b 1..2"], &[], &[]); // a depends on b 1-3 and conditionally on c 1-3 + + let requirements = provider.requirements(&[ + "a", // Require package a + ]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + // Since b=2 is installed (within b 1..2), c should be installed + insta::assert_snapshot!(result, @r###" + a=1 + b=2 + "###); +} + +#[test] +fn test_optional_dependencies() { + let mut provider = BundleBoxProvider::new(); + + // Add package a with base dependency on b and optional dependencies via features + provider.add_package( + "a", + 1.into(), + &["b 1"], + &[], + &[("feat1", &["c"]), ("feat2", &["d"])], + ); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("d", 1.into(), &[], &[], &[]); + + // Request package a with both optional features enabled + let requirements = provider.requirements(&["a[feat2]"]); + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + d=1 + "###); +} + +#[test] +fn test_conditonal_requirements_with_extras() { + let mut provider = BundleBoxProvider::new(); + + // Package a has both optional dependencies (via features) and conditional dependencies + provider.add_package( + "a", + 1.into(), + &["b 1"], + &[], + &[("feat1", &["c"]), ("feat2", &["d"])], + ); + provider.add_package("b", 1.into(), &[], &[], &[]); + provider.add_package("b", 2.into(), &[], &[], &[]); + provider.add_package("c", 1.into(), &[], &[], &[]); + provider.add_package("d", 1.into(), &[], &[], &[]); + provider.add_package("e", 1.into(), &[], &[], &[]); + + // Request package a with feat1 enabled, which will pull in c + // This should trigger the conditional requirement on e + let requirements = provider.requirements(&["a[feat1]", "e 1; if c 1"]); + + let mut solver = Solver::new(provider); + let problem = Problem::new().requirements(requirements); + let solved = solver.solve(problem).unwrap(); + let result = transaction_to_string(solver.provider(), &solved); + insta::assert_snapshot!(result, @r###" + a=1 + b=1 + c=1 + e=1 + "###); +} + #[cfg(feature = "serde")] fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef) { let file = std::io::BufWriter::new(std::fs::File::create(destination.as_ref()).unwrap()); diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index 901996c..3629eaf 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -128,7 +128,8 @@ fn main() { let start = Instant::now(); - let problem = Problem::default().requirements(requirements); + let problem = + Problem::default().requirements(requirements.into_iter().map(Into::into).collect()); let mut solver = Solver::new(provider); let mut records = None; let mut error = None;