Skip to content

Commit

Permalink
BUG: Fix passing mode_array in injection-waveform-arguments (#820)
Browse files Browse the repository at this point in the history
* Fix passing mode_array in injection-waveform-arguments

* Add name to authors.md

* Put conversion inside if statement

* Add try/except

* Typo

* Add more informative error message

* Add safe_cast_mode_to_int

---------

Co-authored-by: Lorenzo Pompili <[email protected]>
Co-authored-by: Colm Talbot <[email protected]>
  • Loading branch information
3 people authored Jan 27, 2025
1 parent 32c044d commit 0ebf360
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 1 deletion.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,5 @@ Martin White
Peter Tsun-Ho Pang
Alexandre Sebastien Goettel
Ann-Kristin Malz
Lorenzo Pompili
Sean Hibbitt
11 changes: 10 additions & 1 deletion bilby/gw/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from .utils import (lalsim_GetApproximantFromString,
lalsim_SimInspiralFD,
lalsim_SimInspiralChooseFDWaveform,
lalsim_SimInspiralChooseFDWaveformSequence)
lalsim_SimInspiralChooseFDWaveformSequence,
safe_cast_mode_to_int)

UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will
result in an error in future releases. Make sure all of the waveform kwargs are correctly
Expand Down Expand Up @@ -174,6 +175,13 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista
}

if mode_array is not None:
try:
mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array]
except (ValueError, TypeError) as e:
raise ValueError(
f"Unable to convert mode_array elements to tuples of ints. "
f"mode_array: {mode_array}, Error: {e}"
) from e
gwsignal_dict.update(ModeArray=mode_array)

# Pass extra waveform arguments to gwsignal
Expand Down Expand Up @@ -528,6 +536,7 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0):
if mode_array is not None:
mode_array_lal = lalsim.SimInspiralCreateModeArray()
for mode in mode_array:
mode = tuple(map(safe_cast_mode_to_int, mode))
lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
return waveform_dictionary
Expand Down
35 changes: 35 additions & 0 deletions bilby/gw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,38 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1):
chi,
-1
)


def safe_cast_mode_to_int(value):
"""Converts a string or integer, representing a mode index in a mode array, to an integer.
Raises an error if the value is a float or any unsupported type.
Parameters
---------------
value:
The input value to be cast to an integer.
Returns
----------
int:
The converted integer.
Raises
---------
TypeError
If the input is a float or an unsupported type.
ValueError
If the string cannot be converted to an integer.
"""
if isinstance(value, int):
return value
elif isinstance(value, str):
try:
return int(value)
except ValueError:
raise ValueError(f"Cannot convert string '{value}' to an integer.")
elif isinstance(value, float):
raise TypeError("Conversion from float to int is not allowed.")
else:
raise TypeError(f"Unsupported type '{type(value).__name__}'.")
20 changes: 20 additions & 0 deletions test/gw/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,26 @@ def test_lalsim_SimInspiralChooseFDWaveform(self):
1.5,
)

def test_safe_cast_mode_to_int(self):
# Valid cases
self.assertEqual(gwutils.safe_cast_mode_to_int("2"), 2)
self.assertEqual(gwutils.safe_cast_mode_to_int("-3"), -3)
self.assertEqual(gwutils.safe_cast_mode_to_int(5), 5)

# Invalid string cases
with self.assertRaises(ValueError):
gwutils.safe_cast_mode_to_int("two")
with self.assertRaises(ValueError):
gwutils.safe_cast_mode_to_int("")

# Unsupported types
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(2.0)
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(2.0j)
with self.assertRaises(TypeError):
gwutils.safe_cast_mode_to_int(None)


class TestSkyFrameConversion(unittest.TestCase):

Expand Down

0 comments on commit 0ebf360

Please sign in to comment.