Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to skip generator caching #16

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions BBHX_Phenom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from warnings import warn


@functools.lru_cache(maxsize=128)
def get_waveform_genner(log_mf_min, run_phenomd=True):
# See below where this function is called for description of how we handle
# log_mf_min.
Expand All @@ -19,6 +18,12 @@ def get_waveform_genner(log_mf_min, run_phenomd=True):
return wave_gen


@functools.lru_cache(maxsize=128)
def cached_get_waveform_genner(log_mf_fin, run_phenomd=True):
"""Cached version of get_waveform_genner"""
return get_waveform_genner(log_mf_fin, run_phenomd)


@functools.lru_cache(maxsize=10)
def cached_arange(start, stop, spacing):
return np.arange(start, stop, spacing)
Expand Down Expand Up @@ -129,6 +134,7 @@ def _bbhx_fd(
direct=False,
num_interp=100,
interp_f_lower=1e-4,
cache_generator=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to make the default False. Easier to avoid issues if the user has to "opt-in" to use this

**params
):

Expand Down Expand Up @@ -157,6 +163,8 @@ def _bbhx_fd(
interp_f_lower : float
Lower frequency cutoff used for interpolation when computing the
chirp time.
cache_generator : bool
If true, the BBHx waveform generator is cached based on

Returns
-------
Expand Down Expand Up @@ -269,9 +277,18 @@ def _bbhx_fd(
# To solve this we *round* the *logarithm* of this mass-dependent start
# frequency. The factor of 25 ensures reasonable spacing while doing this.
# So we round down to the nearest 1/25 of the logarithm of the frequency
log_mf_min = int(math.log(f_min*MTSUN_SI*(m1+m2)) * 25)

wave_gen = get_waveform_genner(log_mf_min, run_phenomd=run_phenomd)
log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Why do you remove int()? Previously, we used it to round down.

if cache_generator:
# Use int to round down
wave_gen = cached_get_waveform_genner(
int(log_mf_min),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK, you moved it here.

run_phenomd=run_phenomd,
)
else:
wave_gen = get_waveform_genner(
log_mf_min,
run_phenomd=run_phenomd,
)

if sample_points is None:
if 'delta_f' in params and params['delta_f'] > 0:
Expand Down
25 changes: 25 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,31 @@ def test_phenomhm_mode_array(params, mode_array):
assert len(wf) == 3


@pytest.mark.parametrize("cache_generator", [False, True])
def test_cache_generator(params, cache_generator):
from BBHX_Phenom import cached_get_waveform_genner

# Clear cache for these tests
cached_get_waveform_genner.cache_clear()

params["approximant"] = "BBHX_PhenomD"
params["cache_generator"] = cache_generator

# Build cache if using it
get_fd_det_waveform(**params)

n_calls = 2
for _ in range(n_calls):
get_fd_det_waveform(**params)

cache_info = cached_get_waveform_genner.cache_info()
if cache_generator:
assert cache_info.hits == n_calls
else:
assert cache_info.hits == 0



def test_length_in_time(params, approximant):
params["approximant"] = approximant
# print(params)
Expand Down