diff --git a/pycbc/results/pygrb_postprocessing_utils.py b/pycbc/results/pygrb_postprocessing_utils.py index dea83e196ea..4f36908d7b3 100644 --- a/pycbc/results/pygrb_postprocessing_utils.py +++ b/pycbc/results/pygrb_postprocessing_utils.py @@ -30,7 +30,8 @@ import h5py from scipy import stats -import ligo.segments as segments +from ligo import segments +from ligo.segments.utils import fromsegwizard from pycbc.events.coherent import reweightedsnr_cut from pycbc.events import veto from pycbc import add_common_pycbc_options @@ -38,8 +39,6 @@ logger = logging.getLogger('pycbc.results.pygrb_postprocessing_utils') -from ligo.segments.utils import fromsegwizard - # ============================================================================= # Arguments functions: @@ -374,8 +373,6 @@ def load_data(input_file, ifos, rw_snr_threshold=None, data_tag=None, def apply_vetoes_to_found_injs(found_missed_file, found_injs, ifos, veto_file=None, keys=None): """Separate injections surviving vetoes from vetoed injections. - THIS IS ESSENTIALLY AN EMPTY PLACE HOLDER AT THE MOMENT: IT RETURNS - THE INJECTIONS GIVEN IN INPUT, WITHOUT APPLYING VETOES. Parameters ---------- @@ -412,6 +409,28 @@ def apply_vetoes_to_found_injs(found_missed_file, found_injs, ifos, found_idx = numpy.arange(len(found_injs[ifos[0]+'/end_time'][:])) veto_idx = numpy.array([], dtype=numpy.int64) + if veto_file: + logging.info("Applying data vetoes to found injections...") + for ifo in ifos: + inj_time = found_injs[ifo+'/end_time'][:] + idx, _ = veto.indices_outside_segments(inj_time, [veto_file], ifo, None) + veto_idx = numpy.append(veto_idx, idx) + logging.info("%d injections vetoed due to %s.", len(idx), ifo) + idx, _ = veto.indices_within_segments(inj_time, [veto_file], ifo, None) + found_idx = numpy.intersect1d(found_idx, idx) + veto_idx = numpy.unique(veto_idx) + logging.info("%d injections vetoed.", len(veto_idx)) + logging.info("%d injections surviving vetoes.", len(found_idx)) + + found_after_vetoes = {} + missed_after_vetoes = {} + for key in keep_keys: + if key == 'network/coincident_snr': + found_injs[key] = get_coinc_snr(found_injs) + if isinstance(found_injs[key], numpy.ndarray): + found_after_vetoes[key] = found_injs[key][found_idx] + missed_after_vetoes[key] = found_injs[key][veto_idx] + return found_after_vetoes, missed_after_vetoes, found_idx, veto_idx @@ -510,7 +529,7 @@ def extract_trig_properties(trial_dict, trigs, slide_dict, seg_dict, keys): # Sort the triggers into each slide sorted_trigs = sort_trigs(trial_dict, trigs, slide_dict, seg_dict) - n_surviving_trigs = sum([len(i) for i in sorted_trigs.values()]) + n_surviving_trigs = sum(len(i) for i in sorted_trigs.values()) msg = f"{n_surviving_trigs} triggers found within the trials dictionary " msg += "and sorted." logging.info(msg) @@ -676,7 +695,7 @@ def construct_trials(seg_files, seg_dict, ifos, slide_dict, veto_file, iter_int += 1 - total_trials = sum([len(trial_dict[slide_id]) for slide_id in slide_dict]) + total_trials = sum(len(trial_dict[slide_id]) for slide_id in slide_dict) logging.info("%d trials generated.", total_trials) return trial_dict, total_trials @@ -701,8 +720,8 @@ def sort_stat(time_veto_max_stat): def max_median_stat(slide_dict, time_veto_max_stat, trig_stat, total_trials): """Return maximum and median of trig_stat and sorted time_veto_max_stat""" - max_stat = max([trig_stat[slide_id].max() if trig_stat[slide_id].size - else 0 for slide_id in slide_dict]) + max_stat = max(trig_stat[slide_id].max() if trig_stat[slide_id].size + else 0 for slide_id in slide_dict) full_time_veto_max_stat = sort_stat(time_veto_max_stat)