Skip to content

Commit

Permalink
Merge pull request #20 from mj-will/optimize-caching
Browse files Browse the repository at this point in the history
Optimize caching
  • Loading branch information
WuShichao authored Aug 23, 2024
2 parents aaf6478 + fe75059 commit 62a62c8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
20 changes: 11 additions & 9 deletions BBHX_Phenom.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,21 +288,23 @@ def _bbhx_fd(
# To solve this we *round* the *logarithm* of this mass-dependent start
# frequency. The factor of 25 ensures reasonable spacing while doing this.
# So we round down to the nearest 1/25 of the logarithm of the frequency
log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25
# We only do this if `mf_min` is not specified. If it is then we set this
# None and can easily cache the generator.
if mf_min is None:
log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25
if cache_generator:
log_mf_min = int(log_mf_min)
else:
log_mf_min = None
if cache_generator:
if mf_min is not None:
raise RuntimeError(
"Cannot use `cache_generator` when `mf_min` is specified"
)
# Use int to round down
wave_gen = cached_get_waveform_genner(
int(log_mf_min),
mf_min=None,
log_mf_min=log_mf_min,
mf_min=mf_min,
run_phenomd=run_phenomd,
)
else:
wave_gen = get_waveform_genner(
log_mf_min,
log_mf_min=log_mf_min,
mf_min=mf_min,
run_phenomd=run_phenomd,
)
Expand Down
22 changes: 22 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_cache_generator(params, cache_generator):

params["approximant"] = "BBHX_PhenomD"
params["cache_generator"] = cache_generator
params["mf_min"] = None

# Build cache if using it
get_fd_det_waveform(**params)
Expand All @@ -101,6 +102,27 @@ def test_cache_generator(params, cache_generator):
else:
assert cache_info.hits == 0

def test_cache_generator_mf_min(params):
from BBHX_Phenom import cached_get_waveform_genner

# Clear cache for these tests
cached_get_waveform_genner.cache_clear()

params["approximant"] = "BBHX_PhenomD"
params["cache_generator"] = True
params["mf_min"] = 1e-4

masses = [2e6, 3e6]

# Build cache
get_fd_det_waveform(**params)

for m in masses:
params["mass1"] = m
get_fd_det_waveform(**params)

cache_info = cached_get_waveform_genner.cache_info()
assert cache_info.hits == len(masses)


def test_length_in_time(params, approximant):
Expand Down

0 comments on commit 62a62c8

Please sign in to comment.