Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pycbc_brute_bank to include option to cut wavelength and save … #4506

Merged
merged 12 commits into from
Oct 25, 2023
42 changes: 27 additions & 15 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
help='size of waveform buffer in seconds')
parser.add_argument('--cut-wavelength', type=bool,
help="Option to cut the wavelength")
help="When enabled, the program will adjust the wavelength to fit the specified length in max-signal-length ")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably combine these two options, so if max-signal-length isn't given, it doesn't impose any cut.

parser.add_argument('--max-signal-length', type= float,
help="maximum length of the waveform model")
parser.add_argument('--sample-rate', default=2048, type=float,
help='sample rate in seconds')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
Expand Down Expand Up @@ -273,31 +275,35 @@ class GenUniformWaveform(object):

def generate(self, **kwds):
kwds.update(fdict)
import numpy
from pycbc.waveform.spa_tmplt import spa_length_in_time
flow = numpy.arange(self.f_lower, 100, .1)[::-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line and the next should probably be within the if statement.

maxlen = args.max_signal_length
length = spa_length_in_time(mass1=kwds['mass1'], mass2=kwds['mass2'], f_lower=flow, phase_order=-1)

x = numpy.searchsorted(length, maxlen) - 1
l = length[x]
f = flow[x]
if kwds['approximant'] in pycbc.waveform.fd_approximants():
if args.cut_wavelength:
import numpy
from pycbc.waveform.spa_tmplt import spa_length_in_time
flow = numpy.arange(self.f_lower, 100, .1)[::-1]
maxlen = 512
length = spa_length_in_time(mass1=kwds['mass1'], mass2=kwds['mass2'], f_lower=flow, phase_order=-1)

x = numpy.searchsorted(length, maxlen) - 1
l = length[x]
f = flow[x]
#logging.info("cut wavelength is specified")
hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_lower=f, f_ref=10.0, **kwds)
kwds['flow'] = f
kwds['f_lower'] = f

if 'fratio' in kwds:
hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])

else:
import numpy
kwds.update(fdict)
#import numpy
#logging.info("cut wavelength not specified")
ws = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_lower=self.f_lower, **kwds)
hp = ws[0]
hc = ws[1]

kwds['f_lower'] = self.f_lower

if 'fratio' in kwds:
hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])
else:
Expand Down Expand Up @@ -447,6 +453,8 @@ def cdraw(rtype, ts, te):

tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl
# total_it = args.tau0_crawl/(args.tau0_end - tau0s)
# completed_iterations = 0
go = True

region = 0
Expand All @@ -469,6 +477,7 @@ while tau0s < args.tau0_end:
if r > 10:
conv = uconv
kloop = 0

while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
r += 1
kloop += 1
Expand All @@ -477,9 +486,13 @@ while tau0s < args.tau0_end:
bank, kconv = bank.check_params(gen, params, args.minimal_match)
logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
region, kloop, r, len(bank), kconv, len(bank) - blen)

# completed_iterations += 1
# progress_percentage = completed_iterations / total_it

if uconv:
logging.info('Ratio of convergences: %2.3f' % (kconv / (uconv)))
logging.info('Progress: {:.0%} completed'.format(tau0s/args.tau0_end))
logging.info('Progress: {:.0%} completed'.format(tau0e/args.tau0_end))

if kloop == 1:
okconv = kconv
Expand All @@ -500,4 +513,3 @@ for k in bank.keys():
if val.dtype.char == 'U':
val = val.astype('bytes')
o[k] = val
o['f_lower'] = numpy.ones(len(val)) * args.low_frequency_cutoff