Skip to content

Commit

Permalink
Allow lower case in input
Browse files Browse the repository at this point in the history
  • Loading branch information
teutoburg committed Sep 12, 2024
1 parent 7e59db0 commit c165603
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
19 changes: 15 additions & 4 deletions astar_utils/spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class SpectralType:
luminosity_class : str or None
Roman numeral luminosity class (I-V).
Notes
-----
The constructor string can be supplied in both upper or lower case or a
mixture thereof, meaning "A0V", "a0v", "A0v" and "a0V" are all valid
representations of the same spectral type. The internal attributes are
converted to uppercase upon creation.
Examples
--------
>>> from astar_utils import SpectralType
Expand Down Expand Up @@ -93,8 +100,8 @@ class SpectralType:
spectype: InitVar[str]
_cls_order: ClassVar = "OBAFGKM" # descending Teff
_regex: ClassVar = re.compile(
r"^(?P<spec_cls>[OBAFGKM])(?P<sub_cls>\d(?:\.\d)?)?"
"(?P<lum_cls>I{1,3}|IV|V)?$", re.A | re.I)
r"^(?P<spec_cls>[OBAFGKM])(?P<sub_cls>\d(\.\d)?)?"
"(?P<lum_cls>I{1,3}|IV|V)?$", re.ASCII | re.IGNORECASE)

def __post_init__(self, spectype) -> None:
"""Validate input and populate fields."""
Expand All @@ -103,13 +110,17 @@ def __post_init__(self, spectype) -> None:

classes = match.groupdict()
# Circumvent frozen as per the docs...
object.__setattr__(self, "spectral_class", classes["spec_cls"])
object.__setattr__(self, "luminosity_class", classes["lum_cls"])
object.__setattr__(self, "spectral_class",
str(classes["spec_cls"]).upper())

if classes["sub_cls"] is not None:
object.__setattr__(self, "spectral_subclass",
float(classes["sub_cls"]))

if classes["lum_cls"] is not None:
object.__setattr__(self, "luminosity_class",
str(classes["lum_cls"]).upper())

@property
def _subcls_str(self) -> str:
if self.spectral_subclass is None:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,16 @@ def test_repr(self, ssl_cls, exptcted):
def test_str(self, ssl_cls, exptcted):
spt = SpectralType(ssl_cls)
assert str(spt) == exptcted


class TestUpperLowerCase:
def test_uppers_lower_case(self):
spt_a = SpectralType("A0V")
spt_b = SpectralType("a0v")
assert spt_a == spt_b

@pytest.mark.parametrize("mixcase", ["m2.5III", "M2.5iii"])
def test_uppers_lower_case_mixed(self, mixcase):
spt_a = SpectralType("M2.5III")
spt_b = SpectralType(mixcase)
assert spt_a == spt_b

0 comments on commit c165603

Please sign in to comment.