From c16560394524a8fd7bc17b5538a5f24c28247ddc Mon Sep 17 00:00:00 2001 From: teutoburg Date: Thu, 12 Sep 2024 12:40:36 +0200 Subject: [PATCH] Allow lower case in input --- astar_utils/spectral_types.py | 19 +++++++++++++++---- tests/test_spectral_types.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/astar_utils/spectral_types.py b/astar_utils/spectral_types.py index 9da63d8..95feb43 100644 --- a/astar_utils/spectral_types.py +++ b/astar_utils/spectral_types.py @@ -55,6 +55,13 @@ class SpectralType: luminosity_class : str or None Roman numeral luminosity class (I-V). + Notes + ----- + The constructor string can be supplied in both upper or lower case or a + mixture thereof, meaning "A0V", "a0v", "A0v" and "a0V" are all valid + representations of the same spectral type. The internal attributes are + converted to uppercase upon creation. + Examples -------- >>> from astar_utils import SpectralType @@ -93,8 +100,8 @@ class SpectralType: spectype: InitVar[str] _cls_order: ClassVar = "OBAFGKM" # descending Teff _regex: ClassVar = re.compile( - r"^(?P[OBAFGKM])(?P\d(?:\.\d)?)?" - "(?PI{1,3}|IV|V)?$", re.A | re.I) + r"^(?P[OBAFGKM])(?P\d(\.\d)?)?" + "(?PI{1,3}|IV|V)?$", re.ASCII | re.IGNORECASE) def __post_init__(self, spectype) -> None: """Validate input and populate fields.""" @@ -103,13 +110,17 @@ def __post_init__(self, spectype) -> None: classes = match.groupdict() # Circumvent frozen as per the docs... - object.__setattr__(self, "spectral_class", classes["spec_cls"]) - object.__setattr__(self, "luminosity_class", classes["lum_cls"]) + object.__setattr__(self, "spectral_class", + str(classes["spec_cls"]).upper()) if classes["sub_cls"] is not None: object.__setattr__(self, "spectral_subclass", float(classes["sub_cls"])) + if classes["lum_cls"] is not None: + object.__setattr__(self, "luminosity_class", + str(classes["lum_cls"]).upper()) + @property def _subcls_str(self) -> str: if self.spectral_subclass is None: diff --git a/tests/test_spectral_types.py b/tests/test_spectral_types.py index 598c99b..02e534e 100644 --- a/tests/test_spectral_types.py +++ b/tests/test_spectral_types.py @@ -130,3 +130,16 @@ def test_repr(self, ssl_cls, exptcted): def test_str(self, ssl_cls, exptcted): spt = SpectralType(ssl_cls) assert str(spt) == exptcted + + +class TestUpperLowerCase: + def test_uppers_lower_case(self): + spt_a = SpectralType("A0V") + spt_b = SpectralType("a0v") + assert spt_a == spt_b + + @pytest.mark.parametrize("mixcase", ["m2.5III", "M2.5iii"]) + def test_uppers_lower_case_mixed(self, mixcase): + spt_a = SpectralType("M2.5III") + spt_b = SpectralType(mixcase) + assert spt_a == spt_b