Skip to content

Commit

Permalink
feat: (re-)add solve_with_sparse_repodata (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Jun 7, 2024
1 parent 3569b59 commit e97f37c
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 28 deletions.
20 changes: 10 additions & 10 deletions py-rattler/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion py-rattler/rattler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
PypiPackageData,
PypiPackageEnvironmentData,
)
from rattler.solver import solve
from rattler.solver import solve, solve_with_sparse_repodata

__version__ = _get_rattler_version()
del _get_rattler_version
Expand Down Expand Up @@ -72,6 +72,7 @@
"PypiPackageData",
"PypiPackageEnvironmentData",
"solve",
"solve_with_sparse_repodata",
"Platform",
"install",
"index",
Expand Down
4 changes: 2 additions & 2 deletions py-rattler/rattler/repo_data/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
self,
channel: Channel,
subdir: str,
path: os.PathLike[str],
path: os.PathLike[str] | str,
) -> None:
if not isinstance(channel, Channel):
raise TypeError(
Expand All @@ -37,7 +37,7 @@ def __init__(
"SparseRepoData constructor received unsupported type "
f" {type(path).__name__!r} for the `path` parameter"
)
self._sparse = PySparseRepoData(channel._channel, subdir, path)
self._sparse = PySparseRepoData(channel._channel, subdir, str(path))

def package_names(self) -> List[str]:
"""
Expand Down
4 changes: 2 additions & 2 deletions py-rattler/rattler/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from rattler.solver.solver import solve
from rattler.solver.solver import solve, solve_with_sparse_repodata

__all__ = ["solve"]
__all__ = ["solve", "solve_with_sparse_repodata"]
95 changes: 92 additions & 3 deletions py-rattler/rattler/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import datetime
from typing import List, Optional, Literal, Sequence

from rattler import Channel, Platform, VirtualPackage
from rattler import Channel, Platform, VirtualPackage, SparseRepoData
from rattler.match_spec.match_spec import MatchSpec

from rattler.channel import ChannelPriority
from rattler.rattler import py_solve, PyMatchSpec
from rattler.rattler import py_solve, PyMatchSpec, py_solve_with_sparse_repodata

from rattler.platform.platform import PlatformLiteral
from rattler.repo_data.gateway import Gateway
Expand All @@ -29,7 +29,7 @@ async def solve(
channel_priority: ChannelPriority = ChannelPriority.Strict,
exclude_newer: Optional[datetime.datetime] = None,
strategy: SolveStrategy = "highest",
constraints: Optional[List[MatchSpec | str]] = None,
constraints: Optional[Sequence[MatchSpec | str]] = None,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
Expand Down Expand Up @@ -114,3 +114,92 @@ async def solve(
else [],
)
]


async def solve_with_sparse_repodata(
    specs: Sequence[MatchSpec | str],
    sparse_repodata: Sequence[SparseRepoData],
    locked_packages: Optional[Sequence[RepoDataRecord]] = None,
    pinned_packages: Optional[Sequence[RepoDataRecord]] = None,
    virtual_packages: Optional[Sequence[GenericVirtualPackage | VirtualPackage]] = None,
    timeout: Optional[datetime.timedelta] = None,
    channel_priority: ChannelPriority = ChannelPriority.Strict,
    exclude_newer: Optional[datetime.datetime] = None,
    strategy: SolveStrategy = "highest",
    constraints: Optional[Sequence[MatchSpec | str]] = None,
) -> List[RepoDataRecord]:
    """
    Resolve the dependencies and return the `RepoDataRecord`s
    that should be present in the environment.

    This function is similar to `solve` but instead of querying for repodata
    with a `Gateway` object this function allows you to manually pass in the
    repodata.

    Arguments:
        specs: A list of matchspec to solve. Strings are converted to match
                specs before being handed to the solver.
        sparse_repodata: The repodata to query for the packages.
        locked_packages: Records of packages that are previously selected.
                         If the solver encounters multiple variants of a single
                         package (identified by its name), it will sort the records
                         and select the best possible version. However, if there
                         exists a locked version it will prefer that variant instead.
                         This is useful to reduce the number of packages that are
                         updated when installing new packages. Usually you add the
                         currently installed packages or packages from a lock-file here.
        pinned_packages: Records of packages that are previously selected and CANNOT
                         be changed. If the solver encounters multiple variants of
                         a single package (identified by its name), it will sort the
                         records and select the best possible version. However, if
                         there is a variant available in the `pinned_packages` field it
                         will always select that version no matter what even if that
                         means other packages have to be downgraded.
        virtual_packages: A list of virtual packages considered active.
        timeout: The maximum time the solver is allowed to run. A zero-length
                 (falsy) timedelta is treated the same as `None`, i.e. no timeout.
        channel_priority: (Default = ChannelPriority.Strict) When `ChannelPriority.Strict`
                 the channel that the package is first found in will be used as
                 the only channel for that package. When `ChannelPriority.Disabled`
                 it will search for every package in every channel.
        exclude_newer: Exclude any record that is newer than the given datetime.
                 A naive datetime is interpreted as UTC; a timezone-aware
                 datetime is used as the instant it denotes.
        strategy: The strategy to use when multiple versions of a package are available.

            * `"highest"`: Select the highest compatible version of all packages.
            * `"lowest"`: Select the lowest compatible version of all packages.
            * `"lowest-direct"`: Select the lowest compatible version for all
              direct dependencies but the highest compatible version of transitive
              dependencies.
        constraints: Additional constraints that should be satisfied by the solver.
                 Packages included in the `constraints` are not necessarily installed,
                 but they must be satisfied by the solution.

    Returns:
        Resolved list of `RepoDataRecord`s.
    """

    # Normalise every user-facing argument into the object the native
    # extension function expects.
    py_specs = [
        spec._match_spec if isinstance(spec, MatchSpec) else PyMatchSpec(str(spec), True)
        for spec in specs
    ]
    py_constraints = (
        [
            constraint._match_spec if isinstance(constraint, MatchSpec) else PyMatchSpec(str(constraint), True)
            for constraint in constraints
        ]
        if constraints is not None
        else []
    )
    py_virtual_packages = [
        v_package.into_generic()._generic_virtual_package
        if isinstance(v_package, VirtualPackage)
        else v_package._generic_virtual_package
        for v_package in virtual_packages or []
    ]

    # The native side receives the timeout as a whole number of microseconds.
    timeout_microseconds = int(timeout / datetime.timedelta(microseconds=1)) if timeout else None

    # The native side receives the cutoff as milliseconds since the Unix epoch.
    # Only a naive datetime is stamped as UTC: overwriting the tzinfo of an
    # aware datetime would silently move the cutoff by its UTC offset.
    if exclude_newer is None:
        exclude_newer_timestamp_ms = None
    else:
        if exclude_newer.utcoffset() is None:
            exclude_newer = exclude_newer.replace(tzinfo=datetime.timezone.utc)
        exclude_newer_timestamp_ms = int(exclude_newer.timestamp() * 1000)

    solved_packages = await py_solve_with_sparse_repodata(
        specs=py_specs,
        sparse_repodata=[package._sparse for package in sparse_repodata],
        locked_packages=[package._record for package in locked_packages or []],
        pinned_packages=[package._record for package in pinned_packages or []],
        virtual_packages=py_virtual_packages,
        channel_priority=channel_priority.value,
        timeout=timeout_microseconds,
        exclude_newer_timestamp_ms=exclude_newer_timestamp_ms,
        strategy=strategy,
        constraints=py_constraints,
    )

    return [RepoDataRecord._from_py_record(solved_package) for solved_package in solved_packages]
4 changes: 3 additions & 1 deletion py-rattler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use repo_data::{
};
use run_exports_json::PyRunExportsJson;
use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum};
use solver::py_solve;
use solver::{py_solve, py_solve_with_sparse_repodata};
use version::PyVersion;
use virtual_package::PyVirtualPackage;

Expand Down Expand Up @@ -142,6 +142,8 @@ fn rattler(py: Python<'_>, m: &PyModule) -> PyResult<()> {

m.add_function(wrap_pyfunction!(py_solve, m).unwrap())
.unwrap();
m.add_function(wrap_pyfunction!(py_solve_with_sparse_repodata, m).unwrap())
.unwrap();
m.add_function(wrap_pyfunction!(get_rattler_version, m).unwrap())
.unwrap();
m.add_function(wrap_pyfunction!(py_install, m).unwrap())
Expand Down
95 changes: 87 additions & 8 deletions py-rattler/src/solver.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use chrono::DateTime;
use pyo3::exceptions::PyValueError;
use pyo3::{pyfunction, FromPyObject, PyAny, PyErr, PyResult, Python};
use pyo3::{exceptions::PyValueError, pyfunction, FromPyObject, PyAny, PyErr, PyResult, Python};
use pyo3_asyncio::tokio::future_into_py;
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{resolvo::Solver, RepoDataIter, SolveStrategy, SolverImpl, SolverTask};
use std::sync::Arc;
use tokio::task::JoinError;

use crate::channel::PyChannel;
use crate::platform::PyPlatform;
use crate::repo_data::gateway::PyGateway;
use crate::{
channel::PyChannelPriority, error::PyRattlerError,
generic_virtual_package::PyGenericVirtualPackage, match_spec::PyMatchSpec, record::PyRecord,
Wrap,
channel::{PyChannel, PyChannelPriority},
error::PyRattlerError,
generic_virtual_package::PyGenericVirtualPackage,
match_spec::PyMatchSpec,
platform::PyPlatform,
record::PyRecord,
repo_data::gateway::PyGateway,
PySparseRepoData, Wrap,
};

impl FromPyObject<'_> for Wrap<SolveStrategy> {
Expand Down Expand Up @@ -104,3 +107,79 @@ pub fn py_solve(
}
})
}

#[allow(clippy::too_many_arguments)]
#[pyfunction]
pub fn py_solve_with_sparse_repodata(
py: Python<'_>,
specs: Vec<PyMatchSpec>,
sparse_repodata: Vec<PySparseRepoData>,
constraints: Vec<PyMatchSpec>,
locked_packages: Vec<PyRecord>,
pinned_packages: Vec<PyRecord>,
virtual_packages: Vec<PyGenericVirtualPackage>,
channel_priority: PyChannelPriority,
timeout: Option<u64>,
exclude_newer_timestamp_ms: Option<i64>,
strategy: Option<Wrap<SolveStrategy>>,
) -> PyResult<&'_ PyAny> {
future_into_py(py, async move {
let exclude_newer = exclude_newer_timestamp_ms.and_then(DateTime::from_timestamp_millis);

let sparse_repodata = sparse_repodata
.into_iter()
.map(|s| s.inner.clone())
.collect::<Vec<_>>();

let solve_result = tokio::task::spawn_blocking(move || {
let package_names = specs
.iter()
.filter_map(|match_spec| match_spec.inner.name.clone());

let available_packages = SparseRepoData::load_records_recursive(
sparse_repodata.iter().map(Arc::as_ref),
package_names,
None,
)?;

let task = SolverTask {
available_packages: available_packages
.iter()
.map(RepoDataIter)
.collect::<Vec<_>>(),
locked_packages: locked_packages
.into_iter()
.map(TryInto::try_into)
.collect::<PyResult<Vec<_>>>()?,
pinned_packages: pinned_packages
.into_iter()
.map(TryInto::try_into)
.collect::<PyResult<Vec<_>>>()?,
virtual_packages: virtual_packages.into_iter().map(Into::into).collect(),
specs: specs.into_iter().map(Into::into).collect(),
constraints: constraints.into_iter().map(Into::into).collect(),
timeout: timeout.map(std::time::Duration::from_micros),
channel_priority: channel_priority.into(),
exclude_newer,
strategy: strategy.map_or_else(Default::default, |v| v.0),
};

Ok::<_, PyErr>(
Solver
.solve(task)
.map(|res| res.into_iter().map(Into::into).collect::<Vec<PyRecord>>())
.map_err(PyRattlerError::from)?,
)
})
.await;

match solve_result.map_err(JoinError::try_into_panic) {
Ok(solve_result) => Ok(solve_result?),
Err(Ok(payload)) => std::panic::resume_unwind(payload),
Err(Err(_err)) => Err(PyRattlerError::IoError(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"solver task was cancelled",
)))?,
}
})
}
Loading

0 comments on commit e97f37c

Please sign in to comment.