Skip to content

Commit

Permalink
feat: add constraints to solve (#713)
Browse files Browse the repository at this point in the history
Adds `constraints` to `SolverTask` in Rust and to the `solve` function
in Python.

Fixes #712
  • Loading branch information
baszalmstra committed Jun 3, 2024
1 parent 66c6f2a commit 18e094b
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ regex = "1.10.4"
reqwest = { version = "0.12.3", default-features = false }
reqwest-middleware = "0.3.0"
reqwest-retry = "0.5.0"
resolvo = { version = "0.4.1" }
resolvo = { version = "0.5.0" }
retry-policies = { version = "0.3.0", default-features = false }
rmp-serde = { version = "1.2.0" }
rstest = { version = "0.19.0" }
Expand Down
6 changes: 6 additions & 0 deletions crates/rattler_solve/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ pub struct SolverTask<TAvailablePackagesIterator> {
/// The specs we want to solve
pub specs: Vec<MatchSpec>,

/// Additional constraints that should be satisfied by the solver.
/// Packages included in the `constraints` are not necessarily
/// installed, but they must be satisfied by the solution.
pub constraints: Vec<MatchSpec>,

/// The timeout after which the solver should stop
pub timeout: Option<std::time::Duration>,

Expand All @@ -156,6 +161,7 @@ impl<'r, I: IntoIterator<Item = &'r RepoDataRecord>> FromIterator<I>
pinned_packages: Vec::new(),
virtual_packages: Vec::new(),
specs: Vec::new(),
constraints: Vec::new(),
timeout: None,
channel_priority: ChannelPriority::default(),
exclude_newer: None,
Expand Down
5 changes: 5 additions & 0 deletions crates/rattler_solve/src/libsolv_c/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ impl super::SolverImpl for Solver {
goal.install(id, false);
}

for spec in task.constraints {
let id = pool.intern_matchspec(&spec);
goal.install(id, true);
}

// Construct a solver and solve the problems in the queue
let mut solver = pool.create_solver();
solver.set_flag(SolverFlag::allow_uninstall(), true);
Expand Down
19 changes: 15 additions & 4 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,11 +568,21 @@ impl super::SolverImpl for Solver {
})
.collect();

let root_constraints = task
.constraints
.iter()
.map(|spec| {
let (name, spec) = spec.clone().into_nameless();
let name = name.expect("cannot use matchspec without a name");
let name_id = provider.pool.intern_package_name(name.as_normalized());
provider.pool.intern_version_set(name_id, spec.into())
})
.collect();

// Construct a solver and solve the problems in the queue
let mut solver = LibSolvRsSolver::new(provider);
let solvables = solver
.solve(root_requirements)
.map_err(|unsolvable_or_cancelled| {
let solvables = solver.solve(root_requirements, root_constraints).map_err(
|unsolvable_or_cancelled| {
match unsolvable_or_cancelled {
UnsolvableOrCancelled::Unsolvable(problem) => {
SolveError::Unsolvable(vec![problem
Expand All @@ -583,7 +593,8 @@ impl super::SolverImpl for Solver {
// put a generic message in here for now
UnsolvableOrCancelled::Cancelled(_) => SolveError::Cancelled,
}
})?;
},
)?;

// Get the resulting packages from the solver.
let required_records = solvables
Expand Down
71 changes: 44 additions & 27 deletions crates/rattler_solve/tests/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,27 @@ macro_rules! solver_backend_tests {
_ => panic!("expected a DuplicateRecord error"),
}
}

#[test]
fn test_constraints() {
// There following package is provided as .tar.bz and as .conda in repodata.json
let mut operations = solve::<$T>(
dummy_channel_json_path(),
SimpleSolveTask {
specs: &["foobar"],
constraints: vec!["bors <=1", "nonexisting"],
..SimpleSolveTask::default()
},
)
.unwrap();

// Sort operations by file name to make the test deterministic
operations.sort_by(|a, b| a.file_name.cmp(&b.file_name));

assert_eq!(operations.len(), 2);
assert_eq!(operations[0].file_name, "bors-1.0-bla_1.tar.bz2");
assert_eq!(operations[1].file_name, "foobar-2.1-bla_1.tar.bz2");
}
};
}

Expand Down Expand Up @@ -592,6 +613,7 @@ mod libsolv_c {
virtual_packages: Vec::new(),
available_packages: [libsolv_repodata],
specs,
constraints: Vec::new(),
pinned_packages: Vec::new(),
timeout: None,
channel_priority: ChannelPriority::default(),
Expand Down Expand Up @@ -768,6 +790,7 @@ mod resolvo {
#[derive(Default)]
struct SimpleSolveTask<'a> {
specs: &'a [&'a str],
constraints: Vec<&'a str>,
installed_packages: Vec<RepoDataRecord>,
pinned_packages: Vec<RepoDataRecord>,
virtual_packages: Vec<GenericVirtualPackage>,
Expand All @@ -787,10 +810,17 @@ fn solve<T: SolverImpl + Default>(
.map(|m| MatchSpec::from_str(m, ParseStrictness::Lenient).unwrap())
.collect();

let constraints = task
.constraints
.into_iter()
.map(|m| MatchSpec::from_str(m, ParseStrictness::Lenient).unwrap())
.collect();

let task = SolverTask {
locked_packages: task.installed_packages,
virtual_packages: task.virtual_packages,
specs,
constraints,
pinned_packages: task.pinned_packages,
exclude_newer: task.exclude_newer,
strategy: task.strategy,
Expand Down Expand Up @@ -942,12 +972,11 @@ fn compare_solve_xtensor_xsimd() {
});
}

fn solve_to_get_channel_of_spec(
fn solve_to_get_channel_of_spec<T: SolverImpl + Default>(
spec_str: &str,
expected_channel: &str,
repo_data: Vec<&SparseRepoData>,
channel_priority: ChannelPriority,
use_resolvo: bool,
) {
let spec = MatchSpec::from_str(spec_str, ParseStrictness::Lenient).unwrap();
let specs = vec![spec.clone()];
Expand All @@ -962,11 +991,7 @@ fn solve_to_get_channel_of_spec(
..SolverTask::from_iter(&available_packages)
};

let result = if use_resolvo {
rattler_solve::resolvo::Solver.solve(task).unwrap()
} else {
rattler_solve::libsolv_c::Solver.solve(task).unwrap()
};
let result = T::default().solve(task).unwrap();

let record = result.iter().find(|record| {
record.package_record.name.as_normalized() == spec.name.as_ref().unwrap().as_normalized()
Expand All @@ -980,33 +1005,29 @@ fn channel_specific_requirement() {
read_conda_forge_sparse_repo_data(),
read_pytorch_sparse_repo_data(),
];
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"conda-forge::pytorch-cpu",
"https://conda.anaconda.org/conda-forge/",
repodata.clone(),
ChannelPriority::Strict,
true,
);
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"conda-forge::pytorch-cpu",
"https://conda.anaconda.org/conda-forge/",
repodata.clone(),
ChannelPriority::Disabled,
true,
);
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch::pytorch-cpu",
"https://conda.anaconda.org/pytorch/",
repodata.clone(),
ChannelPriority::Strict,
true,
);
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch::pytorch-cpu",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Disabled,
true,
);
}

Expand All @@ -1017,25 +1038,23 @@ fn channel_priority_strict() {
read_conda_forge_sparse_repo_data(),
read_pytorch_sparse_repo_data(),
];
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch-cpu",
"https://conda.anaconda.org/conda-forge/",
repodata,
ChannelPriority::Strict,
true,
);

// Solve with pytorch as the first channel
let repodata = vec![
read_pytorch_sparse_repo_data(),
read_conda_forge_sparse_repo_data(),
];
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch-cpu",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Strict,
true,
);
}

Expand All @@ -1051,12 +1070,11 @@ fn channel_priority_strict_panic() {
read_conda_forge_sparse_repo_data(),
read_pytorch_sparse_repo_data(),
];
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch-cpu=0.4.1=py36_cpu_1",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Strict,
true,
);
}

Expand All @@ -1066,15 +1084,15 @@ fn channel_priority_disabled() {
read_conda_forge_sparse_repo_data(),
read_pytorch_sparse_repo_data(),
];
solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::resolvo::Solver>(
"pytorch-cpu=0.4.1=py36_cpu_1",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Disabled,
true,
);
}

#[cfg(feature = "libsolv_c")]
#[test]
#[should_panic(
expected = "called `Result::unwrap()` on an `Err` value: Unsolvable([\"package \
Expand All @@ -1086,27 +1104,26 @@ fn channel_priority_strict_libsolv_c() {
read_pytorch_sparse_repo_data(),
];

solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::libsolv_c::Solver>(
"pytorch-cpu=0.4.1=py36_cpu_1",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Strict,
false,
);
}

#[cfg(feature = "libsolv_c")]
#[test]
fn channel_priority_disabled_libsolv_c() {
let repodata = vec![
read_conda_forge_sparse_repo_data(),
read_pytorch_sparse_repo_data(),
];

solve_to_get_channel_of_spec(
solve_to_get_channel_of_spec::<rattler_solve::libsolv_c::Solver>(
"pytorch-cpu=0.4.1=py36_cpu_1",
"https://conda.anaconda.org/pytorch/",
repodata,
ChannelPriority::Disabled,
false,
);
}
38 changes: 36 additions & 2 deletions py-rattler/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions py-rattler/rattler/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ async def solve(
channel_priority: ChannelPriority = ChannelPriority.Strict,
exclude_newer: Optional[datetime.datetime] = None,
strategy: SolveStrategy = "highest",
constraints: Optional[List[MatchSpec | str]] = None,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
Expand Down Expand Up @@ -69,6 +70,9 @@ async def solve(
* `"lowest-direct"`: Select the lowest compatible version for all
direct dependencies but the highest compatible version of transitive
dependencies.
constraints: Additional constraints that should be satisfied by the solver.
Packages included in the `constraints` are not necessarily installed,
but they must be satisfied by the solution.
Returns:
Resolved list of `RepoDataRecord`s.
Expand Down Expand Up @@ -97,5 +101,11 @@ async def solve(
if exclude_newer
else None,
strategy=strategy,
constraints=[
constraint._match_spec if isinstance(constraint, MatchSpec) else PyMatchSpec(str(constraint), True)
for constraint in constraints
]
if constraints is not None
else [],
)
]
Loading

0 comments on commit 18e094b

Please sign in to comment.