diff --git a/newsfragments/709.feature b/newsfragments/709.feature new file mode 100644 index 000000000..f3ecaa677 --- /dev/null +++ b/newsfragments/709.feature @@ -0,0 +1 @@ +Allows stepping through XTC streams at specific indices provided by a text file diff --git a/src/dxtbx/format/FormatXTC.py b/src/dxtbx/format/FormatXTC.py index e405c6baf..3016ce4fd 100644 --- a/src/dxtbx/format/FormatXTC.py +++ b/src/dxtbx/format/FormatXTC.py @@ -3,6 +3,7 @@ import functools import sys import time +from itertools import groupby import numpy as np import serialtbx.detector.xtc @@ -32,11 +33,15 @@ psana = None locator_str = """ + hits_file = None + .type = str + .help = path to a file where each line is 2 numbers separated by a space, a run index, and an event index in the XTC stream experiment = None .type = str .help = Experiment identifier, e.g. mfxo1916 run = None - .type = ints + .type = int + .multiple = True .help = Run number or a list of runs to process mode = idx .type = str @@ -147,6 +152,7 @@ def __init__(self, image_file, **kwargs): self._ds = FormatXTC._get_datasource(image_file, self.params) self._evr = None + self._load_hit_indices() self.populate_events() self._cached_psana_detectors = {} @@ -162,6 +168,17 @@ def __init__(self, image_file, **kwargs): else: self._spectrum_pedestal = None + def _load_hit_indices(self): + self._hit_inds = None + if self.params.hits_file is not None: + assert self.params.mode == "idx" + hits = np.loadtxt(self.params.hits_file, int) + hits = list(map(tuple, hits)) + key = lambda x: x[0] + gb = groupby(sorted(hits, key=key), key=key) + # dictionary where key is run number, and vals are indices of hits + self._hit_inds = {r:[ind for _,ind in group] for r,group in gb} + @staticmethod def understand(image_file): """Extracts the datasource and detector_address from the image_file and then feeds it to PSANA @@ -229,18 +246,27 @@ def populate_events(self): self.run_mapping = {} if self.params.mode == "idx": - for run in self._psana_runs.values(): + for run_num, run in self._psana_runs.items(): times = run.times() + if self._hit_inds is not None and run_num not in self._hit_inds: + continue + if self._hit_inds is not None and run_num in self._hit_inds: + temp = [] + for i_hit in self._hit_inds[run_num]: + temp.append( times[i_hit] ) + times = tuple(temp) if ( self.params.filter.required_present_codes or self.params.filter.required_absent_codes ) and self.params.filter.pre_filter: times = [t for t in times if self.filter_event(run.event(t))] - self.run_mapping[run.run()] = ( + + self.run_mapping[run_num] = ( len(self.times), len(self.times) + len(times), run, ) + self.times.extend(times) self.n_images = len(self.times)