Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix passing mode_array in injection-waveform-arguments #820

Merged
merged 9 commits into from
Jan 27, 2025
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,5 @@ Martin White
Peter Tsun-Ho Pang
Alexandre Sebastien Goettel
Ann-Kristin Malz
Lorenzo Pompili
Sean Hibbitt
11 changes: 10 additions & 1 deletion bilby/gw/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from .utils import (lalsim_GetApproximantFromString,
lalsim_SimInspiralFD,
lalsim_SimInspiralChooseFDWaveform,
lalsim_SimInspiralChooseFDWaveformSequence)
lalsim_SimInspiralChooseFDWaveformSequence,
safe_cast_mode_to_int)

UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will
result in an error in future releases. Make sure all of the waveform kwargs are correctly
Expand Down Expand Up @@ -174,6 +175,13 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista
}

if mode_array is not None:
try:
mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array]
except (ValueError, TypeError) as e:
raise ValueError(
f"Unable to convert mode_array elements to tuples of ints. "
f"mode_array: {mode_array}, Error: {e}"
) from e
gwsignal_dict.update(ModeArray=mode_array)

# Pass extra waveform arguments to gwsignal
Expand Down Expand Up @@ -528,6 +536,7 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0):
if mode_array is not None:
mode_array_lal = lalsim.SimInspiralCreateModeArray()
for mode in mode_array:
mode = tuple(map(safe_cast_mode_to_int, mode))
lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
return waveform_dictionary
Expand Down
35 changes: 35 additions & 0 deletions bilby/gw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,38 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1):
chi,
-1
)


def safe_cast_mode_to_int(value):
"""Converts a string or integer, representing a mode index in a mode array, to an integer.

Raises an error if the value is a float or any unsupported type.

Parameters
---------------
value:
The input value to be cast to an integer.

Returns
----------
int:
The converted integer.

Raises
---------
TypeError
If the input is a float or an unsupported type.
ValueError
If the string cannot be converted to an integer.
"""
if isinstance(value, int):
return value
elif isinstance(value, str):
try:
return int(value)
except ValueError:
raise ValueError(f"Cannot convert string '{value}' to an integer.")
elif isinstance(value, float):
raise TypeError("Conversion from float to int is not allowed.")
else:
raise TypeError(f"Unsupported type '{type(value).__name__}'.")
20 changes: 20 additions & 0 deletions test/gw/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,26 @@ def test_lalsim_SimInspiralChooseFDWaveform(self):
1.5,
)

def test_safe_cast_mode_to_int(self):
# Valid cases
self.assertEqual(gwutils.safe_cast_mode_to_int("2"), 2)
self.assertEqual(gwutils.safe_cast_mode_to_int("-3"), -3)
self.assertEqual(gwutils.safe_cast_mode_to_int(5), 5)

# Invalid string cases
with self.assertRaises(ValueError):
gwutils.safe_cast_mode_to_int("two")
with self.assertRaises(ValueError):
gwutils.safe_cast_mode_to_int("")

# Unsupported types
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(2.0)
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(2.0j)
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(None)


class TestSkyFrameConversion(unittest.TestCase):

Expand Down
Loading