From 3fff3d27da6448025f644a35f7c0a2d3ddad22b8 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 May 2024 17:21:01 +0100 Subject: [PATCH 1/3] add `cache_generator` option --- BBHX_Phenom.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/BBHX_Phenom.py b/BBHX_Phenom.py index 9a842dc..0a2dbab 100644 --- a/BBHX_Phenom.py +++ b/BBHX_Phenom.py @@ -8,7 +8,6 @@ from warnings import warn -@functools.lru_cache(maxsize=128) def get_waveform_genner(log_mf_min, run_phenomd=True): # See below where this function is called for description of how we handle # log_mf_min. @@ -19,6 +18,12 @@ def get_waveform_genner(log_mf_min, run_phenomd=True): return wave_gen +@functools.lru_cache(maxsize=128) +def cached_get_waveform_genner(log_mf_fin, run_phenomd=True): + """Cached version of get_waveform_genner""" + return get_waveform_genner(log_mf_fin, run_phenomd) + + @functools.lru_cache(maxsize=10) def cached_arange(start, stop, spacing): return np.arange(start, stop, spacing) @@ -129,6 +134,7 @@ def _bbhx_fd( direct=False, num_interp=100, interp_f_lower=1e-4, + cache_generator=True, **params ): @@ -157,6 +163,8 @@ def _bbhx_fd( interp_f_lower : float Lower frequency cutoff used for interpolation when computing the chirp time. + cache_generator : bool + If true, the BBHx waveform generator is cached based on Returns ------- @@ -269,9 +277,18 @@ def _bbhx_fd( # To solve this we *round* the *logarithm* of this mass-dependent start # frequency. The factor of 25 ensures reasonable spacing while doing this. # So we round down to the nearest 1/25 of the logarithm of the frequency - log_mf_min = int(math.log(f_min*MTSUN_SI*(m1+m2)) * 25) - - wave_gen = get_waveform_genner(log_mf_min, run_phenomd=run_phenomd) + log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25 + if cache_generator: + # Use int to round down + wave_gen = cached_get_waveform_genner( + int(log_mf_min), + run_phenomd=run_phenomd, + ) + else: + wave_gen = get_waveform_genner( + log_mf_min, + run_phenomd=run_phenomd, + ) if sample_points is None: if 'delta_f' in params and params['delta_f'] > 0: From 6cbca91e2ca9f25871281f4e47e986e6ac990ac3 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 May 2024 17:27:02 +0100 Subject: [PATCH 2/3] add test for `cache_generator` option --- tests.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests.py b/tests.py index 98846b0..7a43905 100644 --- a/tests.py +++ b/tests.py @@ -78,6 +78,28 @@ def test_phenomhm_mode_array(params, mode_array): assert len(wf) == 3 +@pytest.mark.parametrize("cache_generator", [False, True]) +def test_cache_generator(params, cache_generator): + from BBHX_Phenom import cached_get_waveform_genner + + params["approximant"] = "BBHX_PhenomD" + params["cache_generator"] = cache_generator + + # Build cache if using it + get_fd_det_waveform(**params) + + n_calls = 2 + for _ in range(n_calls): + get_fd_det_waveform(**params) + + cache_info = cached_get_waveform_genner.cache_info() + if cache_generator: + assert cache_info.hits == n_calls + else: + assert cache_info.hits == 0 + + + def test_length_in_time(params, approximant): params["approximant"] = approximant # print(params) From 4ce6a41f79b08f3b90417b00805d0f598fc110b9 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 21 May 2024 17:33:28 +0100 Subject: [PATCH 3/3] clear cache in test --- tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests.py b/tests.py index 7a43905..17bf87b 100644 --- a/tests.py +++ b/tests.py @@ -82,6 +82,9 @@ def test_phenomhm_mode_array(params, mode_array): def test_cache_generator(params, cache_generator): from BBHX_Phenom import cached_get_waveform_genner + # Clear cache for these tests + cached_get_waveform_genner.cache_clear() + params["approximant"] = "BBHX_PhenomD" params["cache_generator"] = cache_generator