From 73e57cdbe2828c829522a8bac083205040048c71 Mon Sep 17 00:00:00 2001 From: teutoburg Date: Tue, 17 Sep 2024 15:05:47 +0200 Subject: [PATCH] Allow comparison of SpectralType directly to string If a string represents a valid constructor, it can be compare directly to an instance of SpectralType, which can in some cases be a lot more convenient than having to manually convert the string to a SpectralType. Implementing the gt and ge methods was necessary to also allow reverse comparisons, i.e. `"B0V" < SpectralType("A0V")`. --- astar_utils/spectral_types.py | 37 ++++++++++++++++++++++++++++------- tests/test_spectral_types.py | 29 +++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/astar_utils/spectral_types.py b/astar_utils/spectral_types.py index 420b98d..7e1e493 100644 --- a/astar_utils/spectral_types.py +++ b/astar_utils/spectral_types.py @@ -44,7 +44,9 @@ class SpectralType: In this context, the luminosity class (if any) is ignored for sorting and comparison (<, >, <=, >=), as it represents a second physical dimension. However, instances of this class may also be compared for equality (== and - !=), in which case all three attributes are considered. + !=), in which case all three attributes are considered. It is also possible + to compare instances directly to strings, if the string is a valid + construtor for this class. Attributes ---------- @@ -106,7 +108,7 @@ class SpectralType: def __post_init__(self, spectype) -> None: """Validate input and populate fields.""" if not (match := self._regex.fullmatch(spectype)): - raise ValueError(spectype) + raise ValueError(f"{spectype!r} is not a valid spectral type.") classes = match.groupdict() # Circumvent frozen as per the docs... @@ -178,14 +180,35 @@ def _comp_tuple(self) -> tuple[int, float]: sub_cls = 5 return (self._spec_cls_idx, sub_cls) + @classmethod + def _comp_guard(cls, other): + if isinstance(other, str): + other = cls(other) + if not isinstance(other, cls): + raise TypeError("Can only compare equal types or valid str.") + return other + + def __eq__(self, other) -> bool: + """Return self == other.""" + other = self._comp_guard(other) + return self._comp_tuple == other._comp_tuple + def __lt__(self, other) -> bool: """Return self < other.""" - if not isinstance(other, self.__class__): - raise TypeError("Can only compare equal types.") + other = self._comp_guard(other) return self._comp_tuple < other._comp_tuple def __le__(self, other) -> bool: - """Return self < other.""" - if not isinstance(other, self.__class__): - raise TypeError("Can only compare equal types.") + """Return self <= other.""" + other = self._comp_guard(other) return self._comp_tuple <= other._comp_tuple + + def __gt__(self, other) -> bool: + """Return self > other.""" + other = self._comp_guard(other) + return self._comp_tuple > other._comp_tuple + + def __ge__(self, other) -> bool: + """Return self >= other.""" + other = self._comp_guard(other) + return self._comp_tuple >= other._comp_tuple diff --git a/tests/test_spectral_types.py b/tests/test_spectral_types.py index a6a9e11..ad298dd 100644 --- a/tests/test_spectral_types.py +++ b/tests/test_spectral_types.py @@ -108,6 +108,35 @@ def test_throws_on_invalid_compare(self, operation): operation(SpectralType("A0V"), 42) +class TestComparesStr: + def test_lt(self): + assert SpectralType("A0V") < "A7V" + + def test_le(self): + assert SpectralType("A0V") <= "A0V" + + def test_gt(self): + assert SpectralType("A0V") > "B7V" + + def test_ge(self): + assert SpectralType("A0V") >= "A0V" + + def test_eq(self): + assert SpectralType("A0V") == "A0V" + + def test_ne(self): + assert SpectralType("A0V") != "A1V" + + def test_reverse_le(self): + assert "A0" <= SpectralType("A0V") + + def test_reverse_gt(self): + assert "A7" > SpectralType("A0V") + + def test_reverse_ne(self): + assert "A1" != SpectralType("A0V") + + class TestRepresentations: @pytest.mark.parametrize(("ssl_cls", "exptcted"), [("A0V", "SpectralType('A0V')"),