Vetoes in PyGRB efficiency and page_tables scripts (#4978)
* Vetoes in pycbc_pygrb_page_tables + some syntax streamlining

* Vetoes in pycbc_pygrb_efficiency + some syntax streamlining

* Squashed mchirp retrieval bug in page_tables

* PR review follow-up: comprehension, comment, readability, unused variables

* Cleaner format_pvalue_str

* Cleaner comprehensions
pannarale authored Dec 10, 2024
1 parent 1a7edd4 commit 1929bc2
Showing 2 changed files with 301 additions and 291 deletions.
247 changes: 131 additions & 116 deletions bin/pygrb/pycbc_pygrb_efficiency
@@ -125,7 +125,9 @@ parser.add_argument("--bank-file", action="store", type=str, required=True,
help="Location of the full template bank used.")
ppu.pygrb_add_injmc_opts(parser)
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

@@ -144,6 +146,7 @@ if opts.exclusion_dist_output_file is not None or \
trig_file = opts.trig_file
onsource_file = opts.onsource_file
found_missed_file = opts.found_missed_file
veto_file = opts.veto_file
inj_set_name = opts.injection_set_name
wf_err = opts.waveform_error
cal_errs = {}
@@ -178,76 +181,84 @@ for output_file in [opts.exclusion_dist_output_file,
if output_file is not None:
outdir = os.path.split(os.path.abspath(output_file))[0]
if not os.path.isdir(outdir):
logging.info("Creating the output directoryi %s.", outdir)
logging.info("Creating the output directory %s.", outdir)
os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
opts.veto_category)

# Load triggers (apply reweighted SNR cut), time-slides, and segment dictionary
logging.info("Loading triggers.")
trigs = ppu.load_triggers(trig_file, ifos, vetoes,
rw_snr_threshold=opts.newsnr_threshold)
logging.info("%d offsource triggers surviving reweighted SNR cut.",
len(trigs['network/event_id']))
logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)
logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(opts.seg_files, segment_dict,
ifos, slide_dict, vetoes)
total_trials = sum([len(trial_dict[slide_id]) for slide_id in slide_dict])
logging.info("%d trials generated.", total_trials)
# Extract IFOs
ifos = ppu.extract_ifos(trig_file)

# Extract basic trigger properties and store as dictionaries
trig_time, trig_snr, trig_bestnr = \
ppu.extract_basic_trig_properties(trial_dict, trigs, slide_dict,
segment_dict, opts)

# Calculate BestNR values and maximum
time_veto_max_bestnr = {}
# Generate time-slides dictionary
slide_dict = ppu.load_time_slides(trig_file)

for slide_id in slide_dict:
num_slide_segs = len(trial_dict[slide_id])
time_veto_max_bestnr[slide_id] = np.zeros(num_slide_segs)
# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
trial_dict, total_trials = ppu.construct_trials(
opts.seg_files,
segment_dict,
ifos,
slide_dict,
veto_file
)
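
The refactored ppu.construct_trials now takes the veto file directly and returns the total trial count alongside the trial dictionary. A minimal sketch of the idea, assuming trials are fixed-length sub-segments of each off-source segment and that any trial overlapping a veto interval is dropped (the helper name and signature below are hypothetical, not the PyCBC API):

def construct_trials_sketch(segments, trial_dur, veto_intervals):
    # Split each (start, end) segment into trial_dur-long trials and
    # drop any trial that overlaps a vetoed interval.
    trials = []
    for seg_start, seg_end in segments:
        n_trials = int((seg_end - seg_start) // trial_dur)
        for i in range(n_trials):
            t0 = seg_start + i * trial_dur
            t1 = t0 + trial_dur
            if all(t1 <= v0 or t0 >= v1 for v0, v1 in veto_intervals):
                trials.append((t0, t1))
    return trials, len(trials)

# One 100 s off-source segment, 10 s trials, one veto spanning 35-42 s:
trials, n_trials = construct_trials_sketch([(0, 100)], 10, [(35, 42)])
print(n_trials)  # 8 -- the 30-40 s and 40-50 s trials are vetoed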

# Load triggers (apply reweighted SNR cut, not vetoes)
all_off_trigs = ppu.load_data(trig_file, ifos, data_tag='offsource',
rw_snr_threshold=opts.newsnr_threshold,
slide_id=opts.slide_id)
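
The rw_snr_threshold argument applies the reweighted SNR cut at load time, while vetoes are now handled later through trial construction. Conceptually the cut is a boolean mask over the trigger arrays (a sketch of the idea, not the actual ppu.load_data internals; the >= convention is an assumption):

import numpy as np

rw_snr = np.array([4.8, 6.3, 0.0, 7.1])
rw_snr_threshold = 6.0
keep = rw_snr >= rw_snr_threshold  # assumed cut convention
print(rw_snr[keep])  # [6.3 7.1]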

# Extract needed trigger properties and store them as dictionaries
# Based on trial_dict: if vetoes were applied, trig_* are the veto survivors
keys = ['network/end_time_gc', 'network/reweighted_snr']
trig_data = ppu.extract_trig_properties(
trial_dict,
all_off_trigs,
slide_dict,
segment_dict,
keys
)

# Max BestNR values in each trial: these are stored in a dictionary keyed
# by slide_id, as arrays indexed by trial number
background = {k: np.zeros(len(v)) for k,v in trial_dict.items()}
for slide_id in slide_dict:
trig_times = trig_data[keys[0]][slide_id]
for j, trial in enumerate(trial_dict[slide_id]):
trial_cut = (trial[0] <= trig_time[slide_id])\
& (trig_time[slide_id] < trial[1])
# True whenever the trigger is in the trial
trial_cut = (trial[0] <= trig_times) & (trig_times < trial[1])
# Move on if nothing was in the trial
if not trial_cut.any():
continue
# Max BestNR
time_veto_max_bestnr[slide_id][j] = \
max(trig_bestnr[slide_id][trial_cut])
background[slide_id][j] = max(trig_data[keys[1]][slide_id][trial_cut])

# Max and median values of reweighted SNR,
# and sorted (loudest in trial) reweighted SNR values
max_bestnr, median_bestnr, sorted_bkgd =\
ppu.max_median_stat(slide_dict, background, trig_data[keys[1]],
total_trials)
assert total_trials == len(sorted_bkgd)

logging.info("SNR and bestNR maxima calculated.")
logging.info("Background bestNR calculated.")

# Output details of loudest offsource triggers
# Output details of loudest offsource triggers: only triggers compatible
# with the trial_dict are considered
offsource_trigs = []
sorted_trigs = ppu.sort_trigs(trial_dict, trigs, slide_dict, segment_dict)
sorted_off_trigs = ppu.sort_trigs(
trial_dict,
all_off_trigs,
slide_dict,
segment_dict
)
for slide_id in slide_dict:
offsource_trigs.extend(zip(trig_bestnr[slide_id], sorted_trigs[slide_id]))
offsource_trigs.extend(
zip(trig_data[keys[1]][slide_id], sorted_off_trigs[slide_id])
)
offsource_trigs.sort(key=lambda element: element[0])
offsource_trigs.reverse()
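
The sort-then-reverse pair above is equivalent to a single descending sort, if further streamlining is ever wanted:

offsource_trigs.sort(key=lambda element: element[0], reverse=True)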

# ==========================
# Print loudest SNRs to file
# THIS OUTPUT FILE IS CURRENTLY UNUSED - MAYBE DELETE?
# Note: the only new info from above is the median SNR, bestnr
# and loudest SNR, so could just add this to the above's caption.
# ==========================
max_bestnr, _, full_time_veto_max_bestnr =\
ppu.max_median_stat(slide_dict, time_veto_max_bestnr, trig_bestnr,
total_trials)

# ==========================
# Calculate template chirp masses from bank
# ==========================
# Calculate chirp masses of templates in bank
logging.info('Reading template chirp masses')
with HFile(opts.bank_file, 'r') as bank_file:
template_mchirps = mchirp_from_mass1_mass2(
@@ -261,9 +272,10 @@ with HFile(opts.bank_file, 'r') as bank_file:
if onsource_file:

logging.info("Processing onsource.")
# Get onsource_triggers (apply reweighted SNR cut)
on_trigs = ppu.load_triggers(onsource_file, ifos, vetoes,
rw_snr_threshold=opts.newsnr_threshold)
# Load onsource triggers (apply reweighted SNR cut, not vetoes)
on_trigs = ppu.load_data(onsource_file, ifos, data_tag=None,
rw_snr_threshold=opts.newsnr_threshold,
slide_id=0)

# Calculate chirp mass values
on_mchirp = template_mchirps[on_trigs['network/template_id']]
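
mchirp_from_mass1_mass2 implements the standard chirp-mass combination of the component masses; a self-contained equivalent for reference:

def mchirp(m1, m2):
    # M_chirp = (m1 * m2)**(3/5) / (m1 + m2)**(1/5)
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2

print(round(mchirp(1.4, 1.4), 3))  # 1.219 for a canonical 1.4-1.4 binary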
@@ -288,65 +300,57 @@ if onsource_file:
logging.info("Onsource analysed.")

if loud_on_bestnr_idx is not None:
num_trials_louder = 0
tot_off_snr = np.array([])
for slide_id in slide_dict:
num_trials_louder += sum(time_veto_max_bestnr[slide_id] >
loud_on_bestnr)
tot_off_snr = np.concatenate([tot_off_snr,
time_veto_max_bestnr[slide_id]])
#fap_test = sum(tot_off_snr > loud_on_bestnr)/total_trials
loud_on_fap = num_trials_louder/total_trials
loud_on_fap = sum(sorted_bkgd > loud_on_bestnr) / total_trials

else:
tot_off_snr = np.array([])
for slide_id in slide_dict:
tot_off_snr = np.concatenate([tot_off_snr,
time_veto_max_bestnr[slide_id]])
med_snr = np.median(tot_off_snr)
#loud_on_fap = sum(tot_off_snr > med_snr)/total_trials
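
The new one-liner replaces the per-slide counting loop: the false-alarm probability of the loudest on-source event is simply the fraction of background trials whose loudest statistic beats it. For example:

import numpy as np

sorted_bkgd = np.array([1.2, 2.0, 3.3, 4.1, 5.6])  # loudest per trial
total_trials = len(sorted_bkgd)
loud_on_bestnr = 3.9

loud_on_fap = np.sum(sorted_bkgd > loud_on_bestnr) / total_trials
print(loud_on_fap)  # 0.4 -- two of the five trials were louder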

# =======================
# Post-process injections
# =======================

sites = [ifo[0] for ifo in ifos]

# injs contains the information about found/missed injections AND triggers
# Triggers and injections are discarded if at vetoed times and/or below
# reweighted SNR threshold
injs = ppu.load_triggers(found_missed_file, ifos, vetoes,
rw_snr_threshold=opts.newsnr_threshold)

logging.info("Missed/found injections/triggers loaded.")
# injs contains found/missed injections AND triggers they generated
# The reweighted SNR cut is applied, vetoes are not
injs = ppu.load_data(found_missed_file, ifos, data_tag='injs',
rw_snr_threshold=opts.newsnr_threshold,
slide_id=0)

# Gather injections that were not missed
found_inj = {}
for k in injs.keys():
if 'missed' not in k:
found_inj[k] = injs[k]

# Separate them into found surviving vetoes and found but vetoed
found_after_vetoes, vetoed, *_ = ppu.apply_vetoes_to_found_injs(
found_missed_file,
found_inj,
ifos,
veto_file=veto_file
)
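
ppu.apply_vetoes_to_found_injs splits the found injections into veto survivors and vetoed events. A toy sketch of such a partition, assuming a simple interval test on the trigger times (the real helper derives the veto segments from the veto file and the found/missed file):

import numpy as np

def split_by_veto(found, tc_key, veto_intervals):
    # Partition a dict of equal-length arrays into (kept, vetoed).
    tc = found[tc_key]
    vetoed_mask = np.zeros(len(tc), dtype=bool)
    for v0, v1 in veto_intervals:
        vetoed_mask |= (tc >= v0) & (tc < v1)
    kept = {k: v[~vetoed_mask] for k, v in found.items()}
    cut = {k: v[vetoed_mask] for k, v in found.items()}
    return kept, cut

found = {'found/tc': np.array([10., 36., 80.]),
         'network/reweighted_snr': np.array([6.1, 7.3, 5.2])}
kept, cut = split_by_veto(found, 'found/tc', [(35, 42)])
print(len(kept['found/tc']), len(cut['found/tc']))  # 2 1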

# Calculate quantities not included in trigger files, such as chirp mass
found_trig_mchirp = template_mchirps[injs['network/template_id']]

found_trig_mchirp = template_mchirps[found_after_vetoes['network/template_id']]

# Construct conditions for injection:
# 1) found louder than background,
zero_fap = np.zeros(len(injs['network/end_time_gc'])).astype(bool)
zero_fap_cut = injs['network/reweighted_snr'][:] > max_bestnr
# 1) found (surviving vetoes) louder than background,
zero_fap = np.zeros(len(found_after_vetoes['network/end_time_gc'])).astype(bool)
zero_fap_cut = found_after_vetoes['network/reweighted_snr'] > max_bestnr
zero_fap = zero_fap | (zero_fap_cut)

# 2) found (bestnr > 0) but not louder than background (non-zero FAP)
nonzero_fap = ~zero_fap & (injs['network/reweighted_snr'] != 0)
# 2) found (bestnr > 0, and surviving vetoes) but not louder than background
nonzero_fap = ~zero_fap & (found_after_vetoes['network/reweighted_snr'] != 0)

# 3) missed after being recovered (i.e., vetoed) are not used here
# missed = (~zero_fap) & (~nonzero_fap)
# 3) missed after being recovered (i.e., vetoed) are in vetoed

# Non-zero FAP triggers (g_ifar)
g_ifar = {}
g_ifar['bestnr'] = injs['network/reweighted_snr'][nonzero_fap]
g_ifar['bestnr'] = found_after_vetoes['network/reweighted_snr'][nonzero_fap]
g_ifar['stat'] = np.zeros([len(g_ifar['bestnr'])])
for ix, (mc, bestnr) in \
enumerate(zip(found_trig_mchirp[nonzero_fap], g_ifar['bestnr'])):
g_ifar['stat'][ix] = (full_time_veto_max_bestnr > bestnr).sum()
g_ifar['stat'][ix] = (sorted_bkgd > bestnr).sum()
g_ifar['stat'] = g_ifar['stat'] / total_trials
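
Since sorted_bkgd is sorted in ascending order, the count of louder background trials in the loop above could also be obtained without a Python loop via binary search; a vectorized equivalent:

import numpy as np

sorted_bkgd = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
bestnrs = np.array([2.5, 4.5])
# Number of background values strictly greater than each bestnr
counts = len(sorted_bkgd) - np.searchsorted(sorted_bkgd, bestnrs, side='right')
print(counts / len(sorted_bkgd))  # [0.6 0.2]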

# Set the sigma values
inj_sigma = {ifo: injs[f'{ifo}/sigmasq'][:] for ifo in ifos}
inj_sigma = {ifo: found_after_vetoes[f'{ifo}/sigmasq'][:] for ifo in ifos}
# If the sigmasqs are not populated, we can still do calibration errors,
# but only in the 1-detector case
for ifo in ifos:
@@ -365,9 +369,9 @@ f_resp = {}
for ifo in ifos:
antenna = Detector(ifo)
f_resp[ifo] = ppu.get_antenna_responses(antenna,
injs['found/ra'][:],
injs['found/dec'][:],
injs['found/tc'][:])
found_after_vetoes['found/ra'][:],
found_after_vetoes['found/dec'][:],
found_after_vetoes['found/tc'][:])

inj_sigma_mult = (np.asarray(list(inj_sigma.values())) *
np.asarray(list(f_resp.values())))
@@ -380,12 +384,12 @@ inj_sigma_mean = {}
for ifo in ifos:
inj_sigma_mean[ifo] = ((inj_sigma[ifo]*f_resp[ifo])/inj_sigma_tot).mean()
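
The quantity computed here is each detector's injection-averaged share of the total network sensitivity: per injection, sigma-squared times the antenna response factor is divided by the sum over detectors, and those fractions are then averaged. A numeric sketch with made-up values, assuming inj_sigma_tot is that per-injection sum:

import numpy as np

inj_sigma = {'H1': np.array([4.0, 1.0, 9.0]), 'L1': np.array([1.0, 1.0, 1.0])}
f_resp = {'H1': np.array([0.5, 1.0, 0.2]), 'L1': np.array([1.0, 0.5, 0.9])}

inj_sigma_tot = sum(inj_sigma[ifo] * f_resp[ifo] for ifo in inj_sigma)
inj_sigma_mean = {ifo: ((inj_sigma[ifo] * f_resp[ifo]) / inj_sigma_tot).mean()
                  for ifo in inj_sigma}
for ifo, w in inj_sigma_mean.items():
    print(ifo, round(float(w), 3))  # H1 0.667, L1 0.333; shares sum to 1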

logging.info("%d found injections analysed.", len(injs['found/tc']))

# Process missed injections (injs['missed'])
logging.info("%d missed injections analysed.", len(injs['missed/tc']))
msg = f"{len(found_after_vetoes['found/tc'])} injections found and surviving "
msg += f"vetoes and {len(injs['missed/tc'])} missed injections analysed."
logging.info(msg)

# Create new set of injections for efficiency calculations
# Create new set of injections for efficiency calculations:
# these are as many as the original injections
total_injs = len(injs['found/distance']) + len(injs['missed/distance'])
long_inj = {}
long_inj['dist'] = stats.uniform.rvs(size=total_injs) * \
Expand All @@ -411,7 +415,7 @@ for key in ['mc', 'no_mc']:
found_on_bestnr[key] = np.zeros(num_dist_bins_plus_one)

# Construct FAP list for all found injections
inj_fap = np.zeros(len(injs['found/distance']))
inj_fap = np.zeros(len(found_after_vetoes['found/distance']))
inj_fap[nonzero_fap] = g_ifar['stat']

# Calculate the amplitude error
@@ -434,10 +438,20 @@ logging.info("Calibration amplitude uncertainty calculated.")
# NOTE: the loop on num_mc_injs would fill up the *_inj['dist_mc']'s at the
# same time, so filling them up sequentially will vary the numbers a little
# (this is an MC, order of operations matters!)
found_inj_dist_mc = ppu.mc_cal_wf_errs(num_mc_injs, injs['found/distance'],
cal_error, wav_err, max_dc_cal_error)
missed_inj_dist_mc = ppu.mc_cal_wf_errs(num_mc_injs, injs['missed/distance'],
cal_error, wav_err, max_dc_cal_error)
found_inj_dist_mc = ppu.mc_cal_wf_errs(
num_mc_injs,
found_after_vetoes['found/distance'],
cal_error,
wav_err,
max_dc_cal_error
)
missed_inj_dist_mc = ppu.mc_cal_wf_errs(
num_mc_injs,
np.concatenate((vetoed['found/distance'],injs['missed/distance'])),
cal_error,
wav_err,
max_dc_cal_error
)
long_inj['dist_mc'] = ppu.mc_cal_wf_errs(num_mc_injs, long_inj['dist'],
cal_error, wav_err, max_dc_cal_error)
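
Note the change in the missed population: injections found but vetoed are now folded into the missed set for the Monte-Carlo distance jitter. As for mc_cal_wf_errs itself, a toy version of the idea, assuming row 0 holds the unperturbed distances and each later row applies an independent multiplicative amplitude error (the actual error model lives in the PyGRB utilities):

import numpy as np

def mc_dist_sketch(num_mc, dists, frac_err, seed=0):
    rng = np.random.default_rng(seed)
    out = np.empty((num_mc + 1, len(dists)))
    out[0] = dists  # row 0: unjittered distances
    for i in range(1, num_mc + 1):
        out[i] = dists * (1.0 + frac_err * rng.standard_normal(len(dists)))
    return out

dist_mc = mc_dist_sketch(3, np.array([50., 100., 150.]), 0.1)
print(dist_mc.shape)  # (4, 3): 1 original row + 3 MC realizations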

@@ -452,32 +466,32 @@ else:

distance_count = np.zeros(len(dist_bins))

found_trig_max_bestnr = np.empty(len(injs['network/event_id']))
found_trig_max_bestnr = np.empty(len(found_after_vetoes['network/event_id']))
found_trig_max_bestnr.fill(max_bestnr)

max_bestnr_cut = (injs['network/reweighted_snr'] > found_trig_max_bestnr)
max_bestnr_cut = (found_after_vetoes['network/reweighted_snr'] > found_trig_max_bestnr)

# Check louder than on source
found_trig_loud_on_bestnr = np.empty(len(injs['network/event_id']))
found_trig_loud_on_bestnr = np.empty(len(found_after_vetoes['network/event_id']))
if onsource_file:
found_trig_loud_on_bestnr.fill(loud_on_bestnr)
else:
found_trig_loud_on_bestnr.fill(med_snr)
on_bestnr_cut = injs['network/reweighted_snr'] > found_trig_loud_on_bestnr
found_trig_loud_on_bestnr.fill(median_bestnr)
on_bestnr_cut = found_after_vetoes['network/reweighted_snr'] > found_trig_loud_on_bestnr

# Check whether injection is found for the purposes of exclusion
# distance calculation.
# Found: if louder than all on source
# Missed: if not louder than loudest on source
found_excl = on_bestnr_cut & (more_sig_than_onsource) & \
(injs['network/reweighted_snr'] != 0)
(found_after_vetoes['network/reweighted_snr'] != 0)
# If not missed, double check bestnr against nearby triggers
near_test = np.zeros((found_excl).sum()).astype(bool)
for j, (t, bestnr) in enumerate(zip(injs['found/tc'][found_excl],
injs['network/reweighted_snr'][found_excl])):
for j, (t, bestnr) in enumerate(zip(found_after_vetoes['found/tc'][found_excl],
found_after_vetoes['network/reweighted_snr'][found_excl])):
# 0 is the zero-lag timeslide
near_bestnr = \
trig_bestnr[0][np.abs(trig_time[0]-t) < cluster_window]
trig_data[keys[1]][0][np.abs(trig_data[keys[0]][0]-t) < cluster_window]
near_test[j] = ~((near_bestnr * glitch_check_fac > bestnr).any())
# Apply the local test
c = 0
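
The near_test loop discounts a found injection whenever a zero-lag off-source trigger within cluster_window of its time is loud enough that, scaled by glitch_check_fac, it beats the injection's own statistic. The mask logic in isolation, with made-up numbers:

import numpy as np

cluster_window, glitch_check_fac = 0.1, 0.5
zero_lag_times = np.array([100.00, 100.03, 250.00])
zero_lag_bestnr = np.array([4.0, 9.0, 3.0])

t, inj_bestnr = 100.02, 4.2  # injection trigger time and statistic
near = np.abs(zero_lag_times - t) < cluster_window
survives = not (zero_lag_bestnr[near] * glitch_check_fac > inj_bestnr).any()
print(survives)  # False: the bestnr-9.0 trigger 0.01 s away is too loud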
Expand Down Expand Up @@ -528,6 +542,7 @@ logging.info("Found/missed injection efficiency calculations completed.")
# ==========
# Make plots
# ==========
logging.info("Plotting.")
# Calculate distances (horizontal axis) as means
dist_plot_vals = [np.asarray(dist_bin).mean() for dist_bin in dist_bins]

@@ -578,7 +593,7 @@ yerr_low, yerr_high, fraction_mc = \
red_efficiency = (fraction_mc) - (yerr_low) * scipy.stats.norm.isf(0.1)

# Calculate and save to disk 50% and 90% exclusion distances
# excl_dist dictionary contains 50% and 90% exclusion distances
excl_dist = {}
for percentile in [50, 90]:
eff_idx = np.where(red_efficiency < (percentile / 100.))[0]
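Downstream of the np.where call above (in lines this view does not expand), the exclusion distance at each percentile is read off as the distance at which the error-reduced efficiency first drops below that fraction. A sketch of that final step, with an illustrative efficiency curve (the first-index selection is an assumption):

import numpy as np

dist_plot_vals = np.array([20., 40., 60., 80., 100.])
red_efficiency = np.array([0.99, 0.93, 0.72, 0.41, 0.10])

excl_dist = {}
for percentile in [50, 90]:
    eff_idx = np.where(red_efficiency < (percentile / 100.))[0]
    # First bin whose efficiency falls below the target fraction
    excl_dist[percentile] = dist_plot_vals[eff_idx[0]]
print(excl_dist[50], excl_dist[90])  # 80.0 60.0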
345 changes: 170 additions & 175 deletions bin/pygrb/pycbc_pygrb_page_tables (diff not loaded)
