diff --git a/Cargo.toml b/Cargo.toml index d1050309b..36a342aec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,7 +108,7 @@ regex = "1.10.4" reqwest = { version = "0.12.3", default-features = false } reqwest-middleware = "0.3.0" reqwest-retry = "0.5.0" -resolvo = { version = "0.4.1" } +resolvo = { version = "0.5.0" } retry-policies = { version = "0.3.0", default-features = false } rmp-serde = { version = "1.2.0" } rstest = { version = "0.19.0" } diff --git a/crates/rattler_solve/src/lib.rs b/crates/rattler_solve/src/lib.rs index 1d53f7419..7e4f7c9ec 100644 --- a/crates/rattler_solve/src/lib.rs +++ b/crates/rattler_solve/src/lib.rs @@ -131,6 +131,11 @@ pub struct SolverTask { /// The specs we want to solve pub specs: Vec, + /// Additional constraints that should be satisfied by the solver. + /// Packages included in the `constraints` are not necessarily + /// installed, but they must be satisfied by the solution. + pub constraints: Vec, + /// The timeout after which the solver should stop pub timeout: Option, @@ -156,6 +161,7 @@ impl<'r, I: IntoIterator> FromIterator pinned_packages: Vec::new(), virtual_packages: Vec::new(), specs: Vec::new(), + constraints: Vec::new(), timeout: None, channel_priority: ChannelPriority::default(), exclude_newer: None, diff --git a/crates/rattler_solve/src/libsolv_c/mod.rs b/crates/rattler_solve/src/libsolv_c/mod.rs index ae3dfd4eb..dc6a9de58 100644 --- a/crates/rattler_solve/src/libsolv_c/mod.rs +++ b/crates/rattler_solve/src/libsolv_c/mod.rs @@ -234,6 +234,11 @@ impl super::SolverImpl for Solver { goal.install(id, false); } + for spec in task.constraints { + let id = pool.intern_matchspec(&spec); + goal.install(id, true); + } + // Construct a solver and solve the problems in the queue let mut solver = pool.create_solver(); solver.set_flag(SolverFlag::allow_uninstall(), true); diff --git a/crates/rattler_solve/src/resolvo/mod.rs b/crates/rattler_solve/src/resolvo/mod.rs index b2e85d6a8..5df905b06 100644 --- a/crates/rattler_solve/src/resolvo/mod.rs +++ b/crates/rattler_solve/src/resolvo/mod.rs @@ -568,11 +568,21 @@ impl super::SolverImpl for Solver { }) .collect(); + let root_constraints = task + .constraints + .iter() + .map(|spec| { + let (name, spec) = spec.clone().into_nameless(); + let name = name.expect("cannot use matchspec without a name"); + let name_id = provider.pool.intern_package_name(name.as_normalized()); + provider.pool.intern_version_set(name_id, spec.into()) + }) + .collect(); + // Construct a solver and solve the problems in the queue let mut solver = LibSolvRsSolver::new(provider); - let solvables = solver - .solve(root_requirements) - .map_err(|unsolvable_or_cancelled| { + let solvables = solver.solve(root_requirements, root_constraints).map_err( + |unsolvable_or_cancelled| { match unsolvable_or_cancelled { UnsolvableOrCancelled::Unsolvable(problem) => { SolveError::Unsolvable(vec![problem @@ -583,7 +593,8 @@ impl super::SolverImpl for Solver { // put a generic message in here for now UnsolvableOrCancelled::Cancelled(_) => SolveError::Cancelled, } - })?; + }, + )?; // Get the resulting packages from the solver. let required_records = solvables diff --git a/crates/rattler_solve/tests/backends.rs b/crates/rattler_solve/tests/backends.rs index d9921085e..74bf1073e 100644 --- a/crates/rattler_solve/tests/backends.rs +++ b/crates/rattler_solve/tests/backends.rs @@ -540,6 +540,27 @@ macro_rules! solver_backend_tests { _ => panic!("expected a DuplicateRecord error"), } } + + #[test] + fn test_constraints() { + // There following package is provided as .tar.bz and as .conda in repodata.json + let mut operations = solve::<$T>( + dummy_channel_json_path(), + SimpleSolveTask { + specs: &["foobar"], + constraints: vec!["bors <=1", "nonexisting"], + ..SimpleSolveTask::default() + }, + ) + .unwrap(); + + // Sort operations by file name to make the test deterministic + operations.sort_by(|a, b| a.file_name.cmp(&b.file_name)); + + assert_eq!(operations.len(), 2); + assert_eq!(operations[0].file_name, "bors-1.0-bla_1.tar.bz2"); + assert_eq!(operations[1].file_name, "foobar-2.1-bla_1.tar.bz2"); + } }; } @@ -592,6 +613,7 @@ mod libsolv_c { virtual_packages: Vec::new(), available_packages: [libsolv_repodata], specs, + constraints: Vec::new(), pinned_packages: Vec::new(), timeout: None, channel_priority: ChannelPriority::default(), @@ -768,6 +790,7 @@ mod resolvo { #[derive(Default)] struct SimpleSolveTask<'a> { specs: &'a [&'a str], + constraints: Vec<&'a str>, installed_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, @@ -787,10 +810,17 @@ fn solve( .map(|m| MatchSpec::from_str(m, ParseStrictness::Lenient).unwrap()) .collect(); + let constraints = task + .constraints + .into_iter() + .map(|m| MatchSpec::from_str(m, ParseStrictness::Lenient).unwrap()) + .collect(); + let task = SolverTask { locked_packages: task.installed_packages, virtual_packages: task.virtual_packages, specs, + constraints, pinned_packages: task.pinned_packages, exclude_newer: task.exclude_newer, strategy: task.strategy, @@ -942,12 +972,11 @@ fn compare_solve_xtensor_xsimd() { }); } -fn solve_to_get_channel_of_spec( +fn solve_to_get_channel_of_spec( spec_str: &str, expected_channel: &str, repo_data: Vec<&SparseRepoData>, channel_priority: ChannelPriority, - use_resolvo: bool, ) { let spec = MatchSpec::from_str(spec_str, ParseStrictness::Lenient).unwrap(); let specs = vec![spec.clone()]; @@ -962,11 +991,7 @@ fn solve_to_get_channel_of_spec( ..SolverTask::from_iter(&available_packages) }; - let result = if use_resolvo { - rattler_solve::resolvo::Solver.solve(task).unwrap() - } else { - rattler_solve::libsolv_c::Solver.solve(task).unwrap() - }; + let result = T::default().solve(task).unwrap(); let record = result.iter().find(|record| { record.package_record.name.as_normalized() == spec.name.as_ref().unwrap().as_normalized() @@ -980,33 +1005,29 @@ fn channel_specific_requirement() { read_conda_forge_sparse_repo_data(), read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "conda-forge::pytorch-cpu", "https://conda.anaconda.org/conda-forge/", repodata.clone(), ChannelPriority::Strict, - true, ); - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "conda-forge::pytorch-cpu", "https://conda.anaconda.org/conda-forge/", repodata.clone(), ChannelPriority::Disabled, - true, ); - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch::pytorch-cpu", "https://conda.anaconda.org/pytorch/", repodata.clone(), ChannelPriority::Strict, - true, ); - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch::pytorch-cpu", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Disabled, - true, ); } @@ -1017,12 +1038,11 @@ fn channel_priority_strict() { read_conda_forge_sparse_repo_data(), read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu", "https://conda.anaconda.org/conda-forge/", repodata, ChannelPriority::Strict, - true, ); // Solve with pytorch as the first channel @@ -1030,12 +1050,11 @@ fn channel_priority_strict() { read_pytorch_sparse_repo_data(), read_conda_forge_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Strict, - true, ); } @@ -1051,12 +1070,11 @@ fn channel_priority_strict_panic() { read_conda_forge_sparse_repo_data(), read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu=0.4.1=py36_cpu_1", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Strict, - true, ); } @@ -1066,15 +1084,15 @@ fn channel_priority_disabled() { read_conda_forge_sparse_repo_data(), read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu=0.4.1=py36_cpu_1", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Disabled, - true, ); } +#[cfg(feature = "libsolv_c")] #[test] #[should_panic( expected = "called `Result::unwrap()` on an `Err` value: Unsolvable([\"package \ @@ -1086,15 +1104,15 @@ fn channel_priority_strict_libsolv_c() { read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu=0.4.1=py36_cpu_1", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Strict, - false, ); } +#[cfg(feature = "libsolv_c")] #[test] fn channel_priority_disabled_libsolv_c() { let repodata = vec![ @@ -1102,11 +1120,10 @@ fn channel_priority_disabled_libsolv_c() { read_pytorch_sparse_repo_data(), ]; - solve_to_get_channel_of_spec( + solve_to_get_channel_of_spec::( "pytorch-cpu=0.4.1=py36_cpu_1", "https://conda.anaconda.org/pytorch/", repodata, ChannelPriority::Disabled, - false, ); } diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 2da51324f..ac69075ae 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -29,6 +29,19 @@ dependencies = [ "opaque-debug", ] +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -2975,10 +2988,11 @@ dependencies = [ [[package]] name = "resolvo" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d299d168910c5d71f3c0f5441abe38ca4a6ae21f70fae909bfc6bead28f6620f" +checksum = "e7b73dc355efbb88c372550b92bf17d36bf555ecf319a4783a5b8b7c34488bc5" dependencies = [ + "ahash", "bitvec", "elsa", "event-listener 5.3.0", @@ -4447,6 +4461,26 @@ dependencies = [ "zvariant", ] +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index 98c4bdb31..dfd1204c5 100644 --- a/py-rattler/rattler/solver/solver.py +++ b/py-rattler/rattler/solver/solver.py @@ -29,6 +29,7 @@ async def solve( channel_priority: ChannelPriority = ChannelPriority.Strict, exclude_newer: Optional[datetime.datetime] = None, strategy: SolveStrategy = "highest", + constraints: Optional[List[MatchSpec | str]] = None, ) -> List[RepoDataRecord]: """ Resolve the dependencies and return the `RepoDataRecord`s @@ -69,6 +70,9 @@ async def solve( * `"lowest-direct"`: Select the lowest compatible version for all direct dependencies but the highest compatible version of transitive dependencies. + constraints: Additional constraints that should be satisfied by the solver. + Packages included in the `constraints` are not necessarily installed, + but they must be satisfied by the solution. Returns: Resolved list of `RepoDataRecord`s. @@ -97,5 +101,11 @@ async def solve( if exclude_newer else None, strategy=strategy, + constraints=[ + constraint._match_spec if isinstance(constraint, MatchSpec) else PyMatchSpec(str(constraint), True) + for constraint in constraints + ] + if constraints is not None + else [], ) ] diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index 1960e292f..d957bdea1 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -37,6 +37,7 @@ pub fn py_solve( channels: Vec, platforms: Vec, specs: Vec, + constraints: Vec, gateway: PyGateway, locked_packages: Vec, pinned_packages: Vec, @@ -77,6 +78,7 @@ pub fn py_solve( .collect::>>()?, virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), specs: specs.into_iter().map(Into::into).collect(), + constraints: constraints.into_iter().map(Into::into).collect(), timeout: timeout.map(std::time::Duration::from_micros), channel_priority: channel_priority.into(), exclude_newer, diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py index eac1f2d89..3f511dc5b 100644 --- a/py-rattler/tests/unit/test_solver.py +++ b/py-rattler/tests/unit/test_solver.py @@ -122,3 +122,19 @@ async def test_solve_channel_priority_disabled( == pytorch_channel.base_url ) assert len(solved_data) == 32 + +@pytest.mark.asyncio +async def test_solve_constraints(gateway: Gateway, dummy_channel: Channel) -> None: + solved_data = await solve( + [dummy_channel], + ["foobar"], + constraints=["bors <=1", "nonexisting"], + platforms=["linux-64"], + gateway=gateway, + ) + + assert isinstance(solved_data, list) + assert len(solved_data) == 2 + + assert solved_data[0].file_name == "foobar-2.1-bla_1.tar.bz2" + assert solved_data[1].file_name == "bors-1.0-bla_1.tar.bz2"