Skip to content

Commit

Permalink
Allow comparison of SpectralType directly to string
Browse files Browse the repository at this point in the history
If a string represents a valid constructor, it can be compare directly to
an instance of SpectralType, which can in some cases be a lot more
convenient than having to manually convert the string to a SpectralType.

Implementing the gt and ge methods was necessary to also allow reverse
comparisons, i.e. `"B0V" < SpectralType("A0V")`.
  • Loading branch information
teutoburg committed Sep 17, 2024
1 parent 4bff8c9 commit 73e57cd
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 7 deletions.
37 changes: 30 additions & 7 deletions astar_utils/spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class SpectralType:
In this context, the luminosity class (if any) is ignored for sorting and
comparison (<, >, <=, >=), as it represents a second physical dimension.
However, instances of this class may also be compared for equality (== and
!=), in which case all three attributes are considered.
!=), in which case all three attributes are considered. It is also possible
to compare instances directly to strings, if the string is a valid
construtor for this class.
Attributes
----------
Expand Down Expand Up @@ -106,7 +108,7 @@ class SpectralType:
def __post_init__(self, spectype) -> None:
"""Validate input and populate fields."""
if not (match := self._regex.fullmatch(spectype)):
raise ValueError(spectype)
raise ValueError(f"{spectype!r} is not a valid spectral type.")

classes = match.groupdict()
# Circumvent frozen as per the docs...
Expand Down Expand Up @@ -178,14 +180,35 @@ def _comp_tuple(self) -> tuple[int, float]:
sub_cls = 5
return (self._spec_cls_idx, sub_cls)

@classmethod
def _comp_guard(cls, other):
if isinstance(other, str):
other = cls(other)
if not isinstance(other, cls):
raise TypeError("Can only compare equal types or valid str.")
return other

def __eq__(self, other) -> bool:
"""Return self == other."""
other = self._comp_guard(other)
return self._comp_tuple == other._comp_tuple

def __lt__(self, other) -> bool:
"""Return self < other."""
if not isinstance(other, self.__class__):
raise TypeError("Can only compare equal types.")
other = self._comp_guard(other)
return self._comp_tuple < other._comp_tuple

def __le__(self, other) -> bool:
"""Return self < other."""
if not isinstance(other, self.__class__):
raise TypeError("Can only compare equal types.")
"""Return self <= other."""
other = self._comp_guard(other)
return self._comp_tuple <= other._comp_tuple

def __gt__(self, other) -> bool:
"""Return self > other."""
other = self._comp_guard(other)
return self._comp_tuple > other._comp_tuple

def __ge__(self, other) -> bool:
"""Return self >= other."""
other = self._comp_guard(other)
return self._comp_tuple >= other._comp_tuple
29 changes: 29 additions & 0 deletions tests/test_spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,35 @@ def test_throws_on_invalid_compare(self, operation):
operation(SpectralType("A0V"), 42)


class TestComparesStr:
def test_lt(self):
assert SpectralType("A0V") < "A7V"

def test_le(self):
assert SpectralType("A0V") <= "A0V"

def test_gt(self):
assert SpectralType("A0V") > "B7V"

def test_ge(self):
assert SpectralType("A0V") >= "A0V"

def test_eq(self):
assert SpectralType("A0V") == "A0V"

def test_ne(self):
assert SpectralType("A0V") != "A1V"

def test_reverse_le(self):
assert "A0" <= SpectralType("A0V")

def test_reverse_gt(self):
assert "A7" > SpectralType("A0V")

def test_reverse_ne(self):
assert "A1" != SpectralType("A0V")


class TestRepresentations:
@pytest.mark.parametrize(("ssl_cls", "exptcted"),
[("A0V", "SpectralType('A0V')"),
Expand Down

0 comments on commit 73e57cd

Please sign in to comment.