Skip to content

Commit

Permalink
Descope PR
Browse files Browse the repository at this point in the history
  • Loading branch information
GarethCabournDavies committed Oct 9, 2023
1 parent e652d2a commit d5003c3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 58 deletions.
97 changes: 48 additions & 49 deletions bin/all_sky_search/pycbc_bin_trigger_rates_dq
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import logging
import argparse
import pycbc
import pycbc.events
from pycbc.events import ranking, stat as pystat
from pycbc.events import stat as pystat
from pycbc.types.timeseries import load_timeseries
import numpy as np
import h5py as h5
Expand Down Expand Up @@ -50,51 +50,51 @@ logging.info('Start')

ifo = args.ifo

# Set up a data mask to remove any triggers with sngl_ranking below threshold

# These are the datasets we may need to calculate sngl_ranking
trig_stats = ['snr', 'chisq', 'chisq_dof', 'sg_chisq', 'psd_var_val']

def sngl_ranking_above_threshold(*trig_vals):
"""
Allow sngl_ranking to be compared to the triggers in the select function
"""
# Convert the triggers into a dictionary so it can be passed straight to
# the ranking calculation
trig_dict = {k: v for k, v in zip(trig_stats, trig_vals)}
sngl = ranking.get_sngls_ranking_from_trigs(trig_dict, args.sngl_ranking)
return sngl > args.stat_threshold
# Set up a data mask to remove any triggers with SNR below threshold.
# This works as a pre-filter, as SNR is always greater than or equal
# to sngl_ranking.

with HFile(args.trig_file, 'r') as trig_file:
n_triggers_orig = trig_file[f'{ifo}/snr'].size
logging.info("Trigger file has %d triggers", n_triggers_orig)
logging.info('Generating trigger mask')
idx = trig_file.select(
sngl_ranking_above_threshold,
idx, _ = trig_file.select(
lambda snr: snr > args.stat_threshold if args,
*[f'{ifo}/{trig_s}' for trig_s in trig_stats],
indices_only=True
return_indices=True,
)
data_mask = np.zeros(n_triggers_orig, dtype=bool)
data_mask[idx] = True

logging.info("Getting %s triggers from file with %s > %.3f",
idx.size, args.sngl_ranking, args.stat_threshold)

with SingleDetTriggers(
args.trig_file,
None,
None,
None,
None,
ifo,
premask=data_mask
) as trigs:
# Extract the data we actually need from the data structure:
tmplt_ids = trigs.template_id
trig_times = trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)
logging.info("Getting %s triggers from file with SNR > %.3f",
idx.size, args.stat_threshold)

trigs = SingleDetTriggers(
args.trig_file,
None,
None,
None,
None,
ifo,
premask=data_mask
)

# Extract the data we actually need from the data structure:
tmplt_ids = trigs.template_id
trig_times = trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)
trig_times_int = trig_times.astype('int')

del trigs

logging.info("Cutting triggers with --sngl-ranking below threshold")
if not args.sngl_ranking == 'snr':
keep = stat > args.stat_threshold
tmplt_ids = tmplt_ids[keep]
trig_times = trig_times[keep]
stat = stat[keep]
trig_times_int = trig_times_int[keep]

n_triggers = tmplt_ids.size

dq_logl = np.array([])
Expand Down Expand Up @@ -146,28 +146,25 @@ with h5.File(args.bank_file, 'r') as bank:
locs_dict = {'all_bin': np.arange(0, len(bank['mass1'][:]), 1)}
locs_names = ['all_bin']

logging.info("Placing triggers into bins")
bin_triggers = {}
for bin_name in locs_names:
bin_locs = locs_dict[bin_name]
bin_triggers[bin_name] = np.isin(tmplt_ids, bin_locs)

if args.prune_number > 0:
for bin_name in locs_names:
logging.info('Pruning bin %s', bin_name)
trig_times_bin = trig_times[bin_triggers[bin_name]]
trig_stats_bin = stat[bin_triggers[bin_name]]
bin_locs = locs_dict[bin_name]
inbin = np.isin(tmplt_ids, bin_locs)
trig_times_bin = trig_times[inbin]
trig_stats_bin = stat[inbin]

for j in range(args.prune_number):
max_stat_arg = np.argmax(trig_stats_bin)
remove = np.nonzero(abs(trig_times_bin[max_stat_arg] - trig_times)
< args.prune_window)[0]
remove_inbin = np.nonzero(abs(trig_times_bin[max_stat_arg]
- trig_times_bin) < args.prune_window)[0]
remove = abs(trig_times_bin[max_stat_arg] - trig_times) \
< args.prune_window
logging.info("Prune %d: pruning %d triggers", j, sum(remove))
remove_inbin = abs(trig_times_bin[max_stat_arg]
- trig_times_bin) < args.prune_window
stat[remove] = 0
trig_stats_bin[remove_inbin] = 0
keep = np.nonzero(stat)[0]
logging.info("%d triggers removed through pruning", sum(keep))
keep = np.flatnonzero(stat)
logging.info("%d triggers removed through pruning", len(keep))
trig_times_int = trig_times_int[keep]
tmplt_ids = tmplt_ids[keep]
del stat
Expand All @@ -177,7 +174,9 @@ del trig_times

with h5.File(args.output_file, 'w') as f:
for bin_name in locs_names:
trig_times_bin = trig_times_int[bin_triggers[bin_name]]
bin_locs = locs_dict[bin_name]
inbin = np.isin(tmplt_ids, bin_locs)
trig_times_bin = trig_times_int[inbin]
trig_percentile = np.array([dq_percentiles_time[t]
for t in trig_times_bin])
logging.info('Processing %d triggers in bin %s',
Expand Down
14 changes: 5 additions & 9 deletions pycbc/io/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def select(self, fcn, *args, **kwds):
data[arg] = []

return_indices = kwds.get('return_indices', False)
indices_only = kwds.get('indices_only', False)
indices = np.array([], dtype=np.uint64)

# To conserve memory read the array in chunks
Expand All @@ -81,18 +80,15 @@ def select(self, fcn, *args, **kwds):
# Read each chunk's worth of data and find where it passes the function
partial = [refs[arg][i:r] for arg in args]
keep = fcn(*partial)
if return_indices or indices_only:
if return_indices:
indices = np.concatenate([indices, np.flatnonzero(keep) + i])

if not indices_only:
# Store only the results that pass the function
for arg, part in zip(args, partial):
data[arg].append(part[keep])
# Store only the results that pass the function
for arg, part in zip(args, partial):
data[arg].append(part[keep])

i += chunksize

if indices_only:
return indices.astype(np.uint64)
# Combine the partial results into full arrays
if len(args) == 1:
res = np.concatenate(data[args[0]])
Expand Down Expand Up @@ -418,7 +414,7 @@ def __init__(self, trig_file, bank_file, veto_file,
logging.info('Loading bank')
self.bank = HFile(bank_file, 'r')
else:
logging.info('No bank file given to SingleDetTriggers')
logging.info('No bank file given')
# empty dict in place of non-existent hdf file
self.bank = {}

Expand Down

0 comments on commit d5003c3

Please sign in to comment.