From 0ebf360114450426e2552ae7d939cfdd6d61f117 Mon Sep 17 00:00:00 2001
From: Lorenzo Pompili <109740755+lorenzopompili00@users.noreply.github.com>
Date: Mon, 27 Jan 2025 20:24:07 +0100
Subject: [PATCH] BUG: Fix passing mode_array in injection-waveform-arguments 
 (#820)

* Fix passing mode_array in injection-waveform-arguments

* Add name to authors.md

* Put conversion inside if statement

* Add try/except

* Typo

* Add more informative error message

* Add safe_cast_mode_to_int

---------

Co-authored-by: Lorenzo Pompili <lorenzo.pompili@aei.mpg.de>
Co-authored-by: Colm Talbot <talbotcolm@gmail.com>
---
 AUTHORS.md            |  1 +
 bilby/gw/source.py    | 11 ++++++++++-
 bilby/gw/utils.py     | 35 +++++++++++++++++++++++++++++++++++
 test/gw/utils_test.py | 20 ++++++++++++++++++++
 4 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/AUTHORS.md b/AUTHORS.md
index 406638997..551b11fa4 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -104,4 +104,5 @@ Martin White
 Peter Tsun-Ho Pang
 Alexandre Sebastien Goettel
 Ann-Kristin Malz
+Lorenzo Pompili
 Sean Hibbitt
diff --git a/bilby/gw/source.py b/bilby/gw/source.py
index bd2cd1b64..ba2dd08ee 100644
--- a/bilby/gw/source.py
+++ b/bilby/gw/source.py
@@ -6,7 +6,8 @@
 from .utils import (lalsim_GetApproximantFromString,
                     lalsim_SimInspiralFD,
                     lalsim_SimInspiralChooseFDWaveform,
-                    lalsim_SimInspiralChooseFDWaveformSequence)
+                    lalsim_SimInspiralChooseFDWaveformSequence,
+                    safe_cast_mode_to_int)
 
 UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will
 result in an error in future releases. Make sure all of the waveform kwargs are correctly
@@ -174,6 +175,13 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista
                      }
 
     if mode_array is not None:
+        try:
+            mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array]
+        except (ValueError, TypeError) as e:
+            raise ValueError(
+                f"Unable to convert mode_array elements to tuples of ints. "
+                f"mode_array: {mode_array}, Error: {e}"
+            ) from e
         gwsignal_dict.update(ModeArray=mode_array)
 
     # Pass extra waveform arguments to gwsignal
@@ -528,6 +536,7 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0):
     if mode_array is not None:
         mode_array_lal = lalsim.SimInspiralCreateModeArray()
         for mode in mode_array:
+            mode = tuple(map(safe_cast_mode_to_int, mode))
             lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
         lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
     return waveform_dictionary
diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py
index bbbc6036a..f1f4c0291 100644
--- a/bilby/gw/utils.py
+++ b/bilby/gw/utils.py
@@ -1073,3 +1073,38 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1):
         chi,
         -1
     )
+
+
+def safe_cast_mode_to_int(value):
+    """Converts a string or integer, representing a mode index in a mode array, to an integer.
+
+    Raises an error if the value is a float or any unsupported type.
+
+    Parameters
+    ---------------
+    value:
+         The input value to be cast to an integer.
+
+    Returns
+    ----------
+    int:
+        The converted integer.
+
+    Raises
+    ---------
+    TypeError
+         If the input is a float or an unsupported type.
+    ValueError
+        If the string cannot be converted to an integer.
+    """
+    if isinstance(value, int):
+        return value
+    elif isinstance(value, str):
+        try:
+            return int(value)
+        except ValueError:
+            raise ValueError(f"Cannot convert string '{value}' to an integer.")
+    elif isinstance(value, float):
+        raise TypeError("Conversion from float to int is not allowed.")
+    else:
+        raise TypeError(f"Unsupported type '{type(value).__name__}'.")
diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py
index bb842cc80..cf78849c7 100644
--- a/test/gw/utils_test.py
+++ b/test/gw/utils_test.py
@@ -243,6 +243,26 @@ def test_lalsim_SimInspiralChooseFDWaveform(self):
                 1.5,
             )
 
+    def test_safe_cast_mode_to_int(self):
+        # Valid cases
+        self.assertEqual(gwutils.safe_cast_mode_to_int("2"), 2)
+        self.assertEqual(gwutils.safe_cast_mode_to_int("-3"), -3)
+        self.assertEqual(gwutils.safe_cast_mode_to_int(5), 5)
+
+        # Invalid string cases
+        with self.assertRaises(ValueError):
+            gwutils.safe_cast_mode_to_int("two")
+        with self.assertRaises(ValueError):
+            gwutils.safe_cast_mode_to_int("")
+
+        # Unsupported types
+        with self.assertRaises(TypeError):
+            gwutils.safe_cast_mode_to_int(2.0)
+        with self.assertRaises(TypeError):
+            gwutils.safe_cast_mode_to_int(2.0j)
+        with self.assertRaises(TypeError):
+            gwutils.safe_cast_mode_to_int(None)
+
 
 class TestSkyFrameConversion(unittest.TestCase):