Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor constraints module #537

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions src/fromager/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version

from . import requirements_file
Expand All @@ -12,8 +12,33 @@


class Constraints:
def __init__(self, data: dict[str, Requirement]):
self._data = {canonicalize_name(n): v for n, v in data.items()}
def __init__(self) -> None:
# mapping of canonical names to requirements
# NOTE: Requirement.name is not normalized
self._data: dict[NormalizedName, Requirement] = {}

def __iter__(self) -> typing.Iterable[NormalizedName]:
yield from self._data

def add_constraint(self, unparsed: str) -> None:
"""Add new constraint, must not conflict with any existing constraints"""
req = Requirement(unparsed)
canon_name = canonicalize_name(req.name)
previous = self._data.get(canon_name)
if previous is not None:
raise KeyError(
f"{canon_name}: new constraint '{req}' conflicts with '{previous}'"
)
if requirements_file.evaluate_marker(req, req):
logger.debug(f"adding constraint {req}")
self._data[canon_name] = req

def load_constraints_file(self, constraints_file: str | pathlib.Path) -> None:
"""Load constraints from a constraints file"""
logger.info("loading constraints from %s", constraints_file)
content = requirements_file.parse_requirements_file(constraints_file)
for line in content:
self.add_constraint(line)

def get_constraint(self, name: str) -> Requirement | None:
return self._data.get(canonicalize_name(name))
Expand All @@ -29,20 +54,3 @@ def is_satisfied_by(self, pkg_name: str, version: Version) -> bool:
if constraint:
return constraint.specifier.contains(version, prereleases=True)
return True


def _parse(content: typing.Iterable[str]) -> Constraints:
constraints = {}
for line in content:
req = Requirement(line)
if requirements_file.evaluate_marker(req, req):
constraints[req.name] = req
return Constraints(constraints)


def load(constraints_file: str | pathlib.Path | None) -> Constraints:
if not constraints_file:
return Constraints({})
logger.info("loading constraints from %s", constraints_file)
parsed_req_file = requirements_file.parse_requirements_file(constraints_file)
return _parse(parsed_req_file)
4 changes: 2 additions & 2 deletions src/fromager/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def __init__(
)
self.settings = active_settings
self.input_constraints_uri: str | None
self.constraints = constraints.Constraints()
if constraints_file is not None:
self.input_constraints_uri = constraints_file
self.constraints = constraints.load(constraints_file)
self.constraints.load_constraints_file(constraints_file)
else:
self.input_constraints_uri = None
self.constraints = constraints.Constraints({})
self.sdists_repo = pathlib.Path(sdists_repo).absolute()
self.sdists_downloads = self.sdists_repo / "downloads"
self.sdists_builds = self.sdists_repo / "builds"
Expand Down
2 changes: 1 addition & 1 deletion src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __init__(
self.include_sdists = include_sdists
self.include_wheels = include_wheels
self.sdist_server_url = sdist_server_url
self.constraints = constraints or Constraints({})
self.constraints = constraints or Constraints()
self.req_type = req_type

def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
Expand Down
53 changes: 39 additions & 14 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pathlib
from unittest.mock import Mock, patch

import pytest
from packaging.requirements import Requirement
Expand All @@ -9,40 +8,66 @@


def test_constraint_is_satisfied_by():
c = constraints.Constraints({"foo": Requirement("foo<=1.1")})
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
assert c.is_satisfied_by("foo", "1.1")
assert c.is_satisfied_by("foo", Version("1.0"))
assert c.is_satisfied_by("bar", Version("2.0"))


def test_constraint_canonical_name():
c = constraints.Constraints({"flash_attn": Requirement("flash_attn<=1.1")})
c = constraints.Constraints()
c.add_constraint("flash_attn<=1.1")
assert c.is_satisfied_by("flash_attn", "1.1")
assert c.is_satisfied_by("flash-attn", "1.1")
assert c.is_satisfied_by("Flash-ATTN", "1.1")
assert list(c) == ["flash-attn"]


def test_constraint_not_is_satisfied_by():
c = constraints.Constraints({"foo": Requirement("foo<=1.1")})
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
c.add_constraint("bar>=2.0")
assert not c.is_satisfied_by("foo", "1.2")
assert not c.is_satisfied_by("foo", Version("2.0"))
assert not c.is_satisfied_by("bar", Version("1.0"))


def test_load_empty_constraints_file():
assert constraints.load(None)._data == {}
def test_add_constraint_conflict():
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
c.add_constraint("flit_core==2.0rc3")
with pytest.raises(KeyError):
c.add_constraint("foo<=1.1")
with pytest.raises(KeyError):
c.add_constraint("foo>1.1")
with pytest.raises(KeyError):
c.add_constraint("flit_core>2.0.0")
with pytest.raises(KeyError):
c.add_constraint("flit-core>2.0.0")


def test_allow_prerelease():
c = constraints.Constraints()
c.add_constraint("foo>=1.1")
assert not c.allow_prerelease("foo")
c.add_constraint("bar>=1.1a0")
assert c.allow_prerelease("bar")
c.add_constraint("flit_core==2.0rc3")
assert c.allow_prerelease("flit_core")


def test_load_non_existant_constraints_file(tmp_path: pathlib.Path):
non_existant_file = tmp_path / "non_existant.txt"
c = constraints.Constraints()
with pytest.raises(FileNotFoundError):
constraints.load(non_existant_file)
c.load_constraints_file(non_existant_file)


@patch("fromager.requirements_file.parse_requirements_file")
def test_load_constraints_file(parse_requirements_file: Mock, tmp_path: pathlib.Path):
def test_load_constraints_file(tmp_path: pathlib.Path):
constraint_file = tmp_path / "constraint.txt"
constraint_file.write_text("a\n")
parse_requirements_file.return_value = ["torch==3.1.0"]
assert constraints.load(constraint_file)._data == {
"torch": Requirement("torch==3.1.0")
}
constraint_file.write_text("egg\ntorch==3.1.0 # comment\n")
c = constraints.Constraints()
c.load_constraints_file(constraint_file)
assert list(c) == ["egg", "torch"] # type: ignore
assert c.get_constraint("torch") == Requirement("torch==3.1.0")
17 changes: 10 additions & 7 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ def test_provider_choose_sdist():


def test_provider_choose_either_with_constraint():
constraint = constraints.Constraints(
{"hydra-core": Requirement("hydra-core==1.3.2")}
)
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core==1.3.2")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand All @@ -204,7 +203,8 @@ def test_provider_choose_either_with_constraint():


def test_provider_constraint_mismatch():
constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.1")})
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core<=1.1")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand All @@ -220,7 +220,8 @@ def test_provider_constraint_mismatch():


def test_provider_constraint_match():
constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.3")})
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core<=1.3")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand Down Expand Up @@ -525,7 +526,8 @@ def test_resolve_github():


def test_github_constraint_mismatch():
constraint = constraints.Constraints({"fromager": Requirement("fromager>=1.0")})
constraint = constraints.Constraints()
constraint.add_constraint("fromager>=1.0")
with requests_mock.Mocker() as r:
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager",
Expand All @@ -547,7 +549,8 @@ def test_github_constraint_mismatch():


def test_github_constraint_match():
constraint = constraints.Constraints({"fromager": Requirement("fromager<0.9")})
constraint = constraints.Constraints()
constraint.add_constraint("fromager<0.9")
with requests_mock.Mocker() as r:
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager",
Expand Down
Loading