diff --git a/AUTHORS.md b/AUTHORS.md index 40663899..551b11fa 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -104,4 +104,5 @@ Martin White Peter Tsun-Ho Pang Alexandre Sebastien Goettel Ann-Kristin Malz +Lorenzo Pompili Sean Hibbitt diff --git a/bilby/gw/source.py b/bilby/gw/source.py index bd2cd1b6..ba2dd08e 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -6,7 +6,8 @@ from .utils import (lalsim_GetApproximantFromString, lalsim_SimInspiralFD, lalsim_SimInspiralChooseFDWaveform, - lalsim_SimInspiralChooseFDWaveformSequence) + lalsim_SimInspiralChooseFDWaveformSequence, + safe_cast_mode_to_int) UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will result in an error in future releases. Make sure all of the waveform kwargs are correctly @@ -174,6 +175,13 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista } if mode_array is not None: + try: + mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array] + except (ValueError, TypeError) as e: + raise ValueError( + f"Unable to convert mode_array elements to tuples of ints. " + f"mode_array: {mode_array}, Error: {e}" + ) from e gwsignal_dict.update(ModeArray=mode_array) # Pass extra waveform arguments to gwsignal @@ -528,6 +536,7 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0): if mode_array is not None: mode_array_lal = lalsim.SimInspiralCreateModeArray() for mode in mode_array: + mode = tuple(map(safe_cast_mode_to_int, mode)) lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1]) lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal) return waveform_dictionary diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index bbbc6036..f1f4c029 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -1073,3 +1073,38 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): chi, -1 ) + + +def safe_cast_mode_to_int(value): + """Converts a string or integer, representing a mode index in a mode array, to an integer. + + Raises an error if the value is a float or any unsupported type. + + Parameters + --------------- + value: + The input value to be cast to an integer. + + Returns + ---------- + int: + The converted integer. + + Raises + --------- + TypeError + If the input is a float or an unsupported type. + ValueError + If the string cannot be converted to an integer. + """ + if isinstance(value, int): + return value + elif isinstance(value, str): + try: + return int(value) + except ValueError: + raise ValueError(f"Cannot convert string '{value}' to an integer.") + elif isinstance(value, float): + raise TypeError("Conversion from float to int is not allowed.") + else: + raise TypeError(f"Unsupported type '{type(value).__name__}'.") diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index bb842cc8..cf78849c 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -243,6 +243,26 @@ def test_lalsim_SimInspiralChooseFDWaveform(self): 1.5, ) + def test_safe_cast_mode_to_int(self): + # Valid cases + self.assertEqual(gwutils.safe_cast_mode_to_int("2"), 2) + self.assertEqual(gwutils.safe_cast_mode_to_int("-3"), -3) + self.assertEqual(gwutils.safe_cast_mode_to_int(5), 5) + + # Invalid string cases + with self.assertRaises(ValueError): + gwutils.safe_cast_mode_to_int("two") + with self.assertRaises(ValueError): + gwutils.safe_cast_mode_to_int("") + + # Unsupported types + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(2.0) + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(2.0j) + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(None) + class TestSkyFrameConversion(unittest.TestCase):