From 0ebf360114450426e2552ae7d939cfdd6d61f117 Mon Sep 17 00:00:00 2001 From: Lorenzo Pompili <109740755+lorenzopompili00@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:24:07 +0100 Subject: [PATCH] BUG: Fix passing mode_array in injection-waveform-arguments (#820) * Fix passing mode_array in injection-waveform-arguments * Add name to authors.md * Put conversion inside if statement * Add try/except * Typo * Add more informative error message * Add safe_cast_mode_to_int --------- Co-authored-by: Lorenzo Pompili <lorenzo.pompili@aei.mpg.de> Co-authored-by: Colm Talbot <talbotcolm@gmail.com> --- AUTHORS.md | 1 + bilby/gw/source.py | 11 ++++++++++- bilby/gw/utils.py | 35 +++++++++++++++++++++++++++++++++++ test/gw/utils_test.py | 20 ++++++++++++++++++++ 4 files changed, 66 insertions(+), 1 deletion(-) diff --git a/AUTHORS.md b/AUTHORS.md index 406638997..551b11fa4 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -104,4 +104,5 @@ Martin White Peter Tsun-Ho Pang Alexandre Sebastien Goettel Ann-Kristin Malz +Lorenzo Pompili Sean Hibbitt diff --git a/bilby/gw/source.py b/bilby/gw/source.py index bd2cd1b64..ba2dd08ee 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -6,7 +6,8 @@ from .utils import (lalsim_GetApproximantFromString, lalsim_SimInspiralFD, lalsim_SimInspiralChooseFDWaveform, - lalsim_SimInspiralChooseFDWaveformSequence) + lalsim_SimInspiralChooseFDWaveformSequence, + safe_cast_mode_to_int) UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will result in an error in future releases. Make sure all of the waveform kwargs are correctly @@ -174,6 +175,13 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista } if mode_array is not None: + try: + mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array] + except (ValueError, TypeError) as e: + raise ValueError( + f"Unable to convert mode_array elements to tuples of ints. " + f"mode_array: {mode_array}, Error: {e}" + ) from e gwsignal_dict.update(ModeArray=mode_array) # Pass extra waveform arguments to gwsignal @@ -528,6 +536,7 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0): if mode_array is not None: mode_array_lal = lalsim.SimInspiralCreateModeArray() for mode in mode_array: + mode = tuple(map(safe_cast_mode_to_int, mode)) lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1]) lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal) return waveform_dictionary diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index bbbc6036a..f1f4c0291 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -1073,3 +1073,38 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): chi, -1 ) + + +def safe_cast_mode_to_int(value): + """Converts a string or integer, representing a mode index in a mode array, to an integer. + + Raises an error if the value is a float or any unsupported type. + + Parameters + --------------- + value: + The input value to be cast to an integer. + + Returns + ---------- + int: + The converted integer. + + Raises + --------- + TypeError + If the input is a float or an unsupported type. + ValueError + If the string cannot be converted to an integer. + """ + if isinstance(value, int): + return value + elif isinstance(value, str): + try: + return int(value) + except ValueError: + raise ValueError(f"Cannot convert string '{value}' to an integer.") + elif isinstance(value, float): + raise TypeError("Conversion from float to int is not allowed.") + else: + raise TypeError(f"Unsupported type '{type(value).__name__}'.") diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index bb842cc80..cf78849c7 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -243,6 +243,26 @@ def test_lalsim_SimInspiralChooseFDWaveform(self): 1.5, ) + def test_safe_cast_mode_to_int(self): + # Valid cases + self.assertEqual(gwutils.safe_cast_mode_to_int("2"), 2) + self.assertEqual(gwutils.safe_cast_mode_to_int("-3"), -3) + self.assertEqual(gwutils.safe_cast_mode_to_int(5), 5) + + # Invalid string cases + with self.assertRaises(ValueError): + gwutils.safe_cast_mode_to_int("two") + with self.assertRaises(ValueError): + gwutils.safe_cast_mode_to_int("") + + # Unsupported types + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(2.0) + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(2.0j) + with self.assertRaises(TypeError): + gwutils.safe_cast_mode_to_int(None) + class TestSkyFrameConversion(unittest.TestCase):