From 088240a9a1427ad5a969464d4fc42c83aeb91a43 Mon Sep 17 00:00:00 2001 From: mj-will Date: Thu, 22 Aug 2024 15:09:40 +0100 Subject: [PATCH 1/2] optimize caching for mf_min --- BBHX_Phenom.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/BBHX_Phenom.py b/BBHX_Phenom.py index 53a079d..ac5fb2a 100644 --- a/BBHX_Phenom.py +++ b/BBHX_Phenom.py @@ -288,21 +288,23 @@ def _bbhx_fd( # To solve this we *round* the *logarithm* of this mass-dependent start # frequency. The factor of 25 ensures reasonable spacing while doing this. # So we round down to the nearest 1/25 of the logarithm of the frequency - log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25 + # We only do this if `mf_min` is not specified. If it is then we set this + # None and can easily cache the generator. + if mf_min is None: + log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25 + if cache_generator: + log_mf_min = int(log_mf_min) + else: + log_mf_min = None if cache_generator: - if mf_min is not None: - raise RuntimeError( - "Cannot use `cache_generator` when `mf_min` is specified" - ) - # Use int to round down wave_gen = cached_get_waveform_genner( - int(log_mf_min), - mf_min=None, + log_mf_min=log_mf_min, + mf_min=mf_min, run_phenomd=run_phenomd, ) else: wave_gen = get_waveform_genner( - log_mf_min, + log_mf_min=log_mf_min, mf_min=mf_min, run_phenomd=run_phenomd, ) From fe750591f424ca12ff3a2735b45ca3dbcc234fb3 Mon Sep 17 00:00:00 2001 From: mj-will Date: Thu, 22 Aug 2024 15:09:53 +0100 Subject: [PATCH 2/2] add a test for caching with mf_min --- tests.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests.py b/tests.py index 17bf87b..ca08a8a 100644 --- a/tests.py +++ b/tests.py @@ -87,6 +87,7 @@ def test_cache_generator(params, cache_generator): params["approximant"] = "BBHX_PhenomD" params["cache_generator"] = cache_generator + params["mf_min"] = None # Build cache if using it get_fd_det_waveform(**params) @@ -101,6 +102,27 @@ def test_cache_generator(params, cache_generator): else: assert cache_info.hits == 0 +def test_cache_generator_mf_min(params): + from BBHX_Phenom import cached_get_waveform_genner + + # Clear cache for these tests + cached_get_waveform_genner.cache_clear() + + params["approximant"] = "BBHX_PhenomD" + params["cache_generator"] = True + params["mf_min"] = 1e-4 + + masses = [2e6, 3e6] + + # Build cache + get_fd_det_waveform(**params) + + for m in masses: + params["mass1"] = m + get_fd_det_waveform(**params) + + cache_info = cached_get_waveform_genner.cache_info() + assert cache_info.hits == len(masses) def test_length_in_time(params, approximant):