From e878e5de312ed8da0ffb36ff0e9e7e4a084556c8 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhuang Date: Sat, 4 Jul 2020 01:42:35 +0200 Subject: [PATCH] ChainReader Refactoring (#2815) * Fixes #2814 * changes in ChainReader: * new `_read_frame` and `__iter__` are now properly connected (previously, they moved independently, which could lead to the ChainReader being in an inconsistent state) * remove `_chained_iterator`. * add `__next__` * add an extra attribute `__current_frame` to internally monitor absolute current frame * add tests * update CHANGELOG --- package/CHANGELOG | 5 +- package/MDAnalysis/coordinates/chain.py | 62 +++++++++---------- .../coordinates/test_chainreader.py | 5 ++ 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index 26813db9495..678f0fcf293 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -14,7 +14,8 @@ The rules for this file: ------------------------------------------------------------------------------ ??/??/?? richardjgowers, IAlibay, hmacdope, orbeckst, cbouy, lilyminium, - daveminh, jbarnoud + daveminh, jbarnoud, yuxuanzhuang + * 2.0.0 @@ -25,6 +26,8 @@ Fixes * TOPParser no longer guesses elements when missing atomic number records (Issues #2449, #2651) * Testsuite does not any more matplotlib.use('agg') (#2191) + * In ChainReader, read_frame does not trigger change of iterating position. + (Issue #2723, PR #2815) Enhancements * Added the RDKitParser which creates a `core.topology.Topology` object from diff --git a/package/MDAnalysis/coordinates/chain.py b/package/MDAnalysis/coordinates/chain.py index c0c0fea3e53..d76f34a1140 100644 --- a/package/MDAnalysis/coordinates/chain.py +++ b/package/MDAnalysis/coordinates/chain.py @@ -38,13 +38,11 @@ .. automethod:: _get .. automethod:: _get_same .. automethod:: _read_frame - .. automethod:: _chained_iterator """ import warnings import os.path -import itertools import bisect import copy @@ -141,7 +139,7 @@ def filter_times(times, dt): def check_allowed_filetypes(readers, allowed): """ - Make a check that all readers have the same filetype and are of the + Make a check that all readers have the same filetype and are of the allowed files types. Throws Exception on failure. Parameters @@ -149,7 +147,7 @@ def check_allowed_filetypes(readers, allowed): readers : list of MDA readers allowed : list of allowed formats """ - classname = type(readers[0]) + classname = type(readers[0]) only_one_reader = np.all([isinstance(r, classname) for r in readers]) if not only_one_reader: readernames = [type(r) for r in readers] @@ -158,7 +156,7 @@ def check_allowed_filetypes(readers, allowed): "Found: {}".format(readernames)) if readers[0].format not in allowed: raise NotImplementedError("ChainReader: continuous=True only " - "supported for formats: {}".format(allowed)) + "supported for formats: {}".format(allowed)) class ChainReader(base.ProtoReader): @@ -263,7 +261,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): kwargs['dt'] = dt self.readers = [core.reader(filename, **kwargs) for filename in filenames] - self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn for fn in filenames]) + self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn + for fn in filenames]) # pointer to "active" trajectory index into self.readers self.__active_reader_index = 0 @@ -290,9 +289,6 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): self.dts = np.array(self._get('dt')) self.total_times = self.dts * n_frames - #: source for trajectories frame (fakes trajectory) - self.__chained_trajectories_iter = None - # calculate new start_frames to have a time continuous trajectory. if continuous: check_allowed_filetypes(self.readers, ['XTC', 'TRR']) @@ -346,7 +342,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): # check for interleaving r1[1] if r1_start_time < start_time < r1.time: - raise RuntimeError("ChainReader: Interleaving not supported with continuous=True.") + raise RuntimeError("ChainReader: Interleaving not supported " + "with continuous=True.") # find end where trajectory was restarted from for ts in r1[::-1]: @@ -439,8 +436,6 @@ def copy(self): new.ts = self.ts.copy() return new - - # attributes that can change with the current reader @property def filename(self): @@ -561,49 +556,39 @@ def _read_frame(self, frame): # update Timestep self.ts = self.active_reader.ts self.ts.frame = frame # continuous frames, 0-based + self.__current_frame = frame return self.ts - def _chained_iterator(self): - """Iterator that presents itself as a chained trajectory.""" - self._rewind() # must rewind all readers - for i in range(self.n_frames): - j, f = self._get_local_frame(i) - self.__activate_reader(j) - self.ts = self.active_reader[f] - self.ts.frame = i - yield self.ts def _read_next_timestep(self, ts=None): - self.ts = next(self.__chained_trajectories_iter) - return self.ts + if ts is None: + ts = self.ts + ts = self.__next__() + return ts def rewind(self): """Set current frame to the beginning.""" self._rewind() - self.__chained_trajectories_iter = self._chained_iterator() - # set time step for frame 1 - self.ts = next(self.__chained_trajectories_iter) def _rewind(self): """Internal method: Rewind trajectories themselves and trj pointer.""" + self.__current_frame = -1 self._apply('rewind') - self.__activate_reader(0) + self.__next__() def close(self): self._apply('close') def __iter__(self): - """Generator for all frames, starting at frame 1.""" - self._rewind() + """Generator for all frames, starting at frame 0.""" + self.__current_frame = -1 # start from first frame - self.__chained_trajectories_iter = self._chained_iterator() - for ts in self.__chained_trajectories_iter: - yield ts + return self def __repr__(self): if len(self.filenames) > 3: fnames = "{fname} and {nfanmes} more".format( - fname=os.path.basename(self.filenames[0]), + fname=os.path.basename(self.filenames[0]), nfanmes=len(self.filenames) - 1) else: fnames = ", ".join([os.path.basename(fn) for fn in self.filenames]) @@ -656,3 +641,14 @@ def _apply_transformations(self, ts): # to avoid applying the same transformations multiple times on each frame return ts + + def __next__(self): + if self.__current_frame < self.n_frames - 1: + j, f = self._get_local_frame(self.__current_frame + 1) + self.__activate_reader(j) + self.ts = self.active_reader[f] + self.ts.frame = self.__current_frame + 1 + self.__current_frame += 1 + return self.ts + else: + raise StopIteration() diff --git a/testsuite/MDAnalysisTests/coordinates/test_chainreader.py b/testsuite/MDAnalysisTests/coordinates/test_chainreader.py index c8d1f9e4b48..27d6989763d 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_chainreader.py +++ b/testsuite/MDAnalysisTests/coordinates/test_chainreader.py @@ -96,6 +96,11 @@ def test_frame_numbering(self, universe): universe.trajectory[98] # index is 0-based and frames are 0-based assert_equal(universe.trajectory.frame, 98, "wrong frame number") + def test_next_after_frame_numbering(self, universe): + universe.trajectory[98] # index is 0-based and frames are 0-based + universe.trajectory.next() + assert_equal(universe.trajectory.frame, 99, "wrong frame number") + def test_frame(self, universe): universe.trajectory[0] coord0 = universe.atoms.positions.copy()