Skip to content

Commit

Permalink
Update pycbc_brute_bank to include option to cut wavelength and save … (
Browse files Browse the repository at this point in the history
gwastro#4506)

* Update pycbc_brute_bank to include option to cut wavelength and save the lower frequency

* Removed Logging

* Update pycbc_brute_bank

Still need to fix the completion percentage

* Update pycbc_brute_bank

* Update pycbc_brute_bank

* Update pycbc_brute_bank

Debugged test example

* Update pycbc_brute_bank

* Update pycbc_brute_bank

* Testing Condor

Condor installation is not working properly for the checks after I changed two lines.

* Testing 

Testing with file used before swapping two lines caused errors

* Update pycbc_brute_bank

Swapped two lines to be within the if statment
  • Loading branch information
kkacanja authored and maxtrevor committed Dec 5, 2023
1 parent 0c3b7b5 commit b72cfa8
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions bin/bank/pycbc_brute_bank
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy, h5py, logging, argparse, numpy.random
import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
from pycbc.distributions.utils import draw_samples_from_config, prior_from_config
from scipy.stats import gaussian_kde
Expand All @@ -45,6 +46,8 @@ parser.add_argument('--approximant', required=True,
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
help='size of waveform buffer in seconds')
parser.add_argument('--max-signal-length', type= float,
help="When specified, it cuts the maximum length of the waveform model to the lengh provided")
parser.add_argument('--sample-rate', default=2048, type=float,
help='sample rate in seconds')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
Expand Down Expand Up @@ -269,15 +272,28 @@ class GenUniformWaveform(object):
self.md = q._data[-100:]
self.md2 = q._data[0:100]

def generate(self, **kwds):
def generate(self, **kwds):
kwds.update(fdict)
if kwds['approximant'] in pycbc.waveform.fd_approximants():
ws = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_lower=self.f_lower, **kwds)
hp = ws[0]
hc = ws[1]
if args.max_signal_length is not None:
flow = numpy.arange(self.f_lower, 100, .1)[::-1]
length = spa_length_in_time(mass1=kwds['mass1'], mass2=kwds['mass2'], f_lower=flow, phase_order=-1)
maxlen = args.max_signal_length
x = numpy.searchsorted(length, maxlen) - 1
l = length[x]
f = flow[x]
else:
f = self.f_lower

kwds['f_lower'] = f

if kwds['approximant'] in pycbc.waveform.fd_approximants():
hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
f_ref=10.0, **kwds)


if 'fratio' in kwds:
hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])

else:
dt = 1.0 / args.sample_rate
hp = pycbc.waveform.get_waveform_filter(
Expand Down Expand Up @@ -342,10 +358,10 @@ def draw(rtype):
p = bank.keys()
p = [k for k in p if k not in fdict]
p.remove('approximant')
p.remove('f_lower')
if args.input_config is not None:
p = variable_args
bdata = numpy.array([bank.key(k)[-trail:] for k in p])

kde = gaussian_kde(bdata)
points = kde.resample(size=size)
params = {k: v for k, v in zip(p, points)}
Expand Down Expand Up @@ -422,9 +438,10 @@ def cdraw(rtype, ts, te):
return None

return p

tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

go = True

region = 0
Expand All @@ -447,6 +464,7 @@ while tau0s < args.tau0_end:
if r > 10:
conv = uconv
kloop = 0

while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
r += 1
kloop += 1
Expand All @@ -455,9 +473,11 @@ while tau0s < args.tau0_end:
bank, kconv = bank.check_params(gen, params, args.minimal_match)
logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
region, kloop, r, len(bank), kconv, len(bank) - blen)


if uconv:
logging.info('Ratio of convergences: %2.3f' % (kconv / (uconv)))
logging.info('Progress: {:.0%} completed'.format(tau0s/args.tau0_end))
logging.info('Progress: {:.0%} completed'.format(tau0e/args.tau0_end))

if kloop == 1:
okconv = kconv
Expand All @@ -473,9 +493,9 @@ while tau0s < args.tau0_end:
tau0e += args.tau0_crawl / 2

o = h5py.File(args.output_file, 'w')

for k in bank.keys():
val = bank.key(k)
if val.dtype.char == 'U':
val = val.astype('bytes')
o[k] = val
o['f_lower'] = numpy.ones(len(val)) * args.low_frequency_cutoff

0 comments on commit b72cfa8

Please sign in to comment.