From fa6bc1fec8a2ff2d76c7b286b16059c37a7748d0 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhuang Date: Sat, 4 Jul 2020 01:42:35 +0200 Subject: [PATCH] ChainReader Refactoring (#2815) * Fixes #2814 * changes in ChainReader: * new `_read_frame` and `__iter__` are now properly connected (previously, they moved independently, which could lead to the ChainReader being in an inconsistent state) * remove `_chained_iterator`. * add `__next__` * add an extra attribute `__current_frame` to internally monitor absolute current frame * add tests * update CHANGELOG --- package/CHANGELOG | 6 +- package/MDAnalysis/coordinates/chain.py | 62 +++++++++---------- .../coordinates/test_chainreader.py | 5 ++ 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index 59f71ae92ee..08bd8ff7b53 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -14,7 +14,7 @@ The rules for this file: ------------------------------------------------------------------------------ ??/??/20 richardjgowers, IAlibay, orbeckst, tylerjereddy, jbarnoud, - VOD555, lilyminium + yuxuanzhuang, VOD555, lilyminium * 1.0.1 @@ -22,6 +22,8 @@ Fixes * Development status changed from beta to mature (Issue #2773) * pip installation only requests Python 2.7-compatible packages (#2736) * Testsuite does not use any more matplotlib.use('agg') (#2191) + * In ChainReader, read_frame does not trigger change of iterating position. + (Issue #2723, PR #2815) * rdf.InterRDF_s density keyword documented and now gives correct results for density=True; the keyword was available since 0.19.0 but with incorrect semantics and not documented and did not produce correct results (Issue @@ -47,7 +49,7 @@ Fixes (Issue #2565) * Made NoDataError a subclass of ValueError *and* AttributeError (Issue #2635) - * Fixed select_atoms("around 0.0 ...") selections and capped_distance + * Fixed select_atoms("around 0.0 ...") selections and capped_distance causing a segfault (Issue #2656 PR #2665) * `PDBWriter` writes unitary `CRYST1` record (cubic box with sides of 1 Å) when `u.dimensions` is `None` or `np.zeros(6)` (Issue #2679, PR #2685) diff --git a/package/MDAnalysis/coordinates/chain.py b/package/MDAnalysis/coordinates/chain.py index 2a7604f78e2..b2e06d55495 100644 --- a/package/MDAnalysis/coordinates/chain.py +++ b/package/MDAnalysis/coordinates/chain.py @@ -38,7 +38,6 @@ .. automethod:: _get .. automethod:: _get_same .. automethod:: _read_frame - .. automethod:: _chained_iterator """ from __future__ import absolute_import @@ -46,7 +45,6 @@ import warnings import os.path -import itertools import bisect import copy @@ -143,7 +141,7 @@ def filter_times(times, dt): def check_allowed_filetypes(readers, allowed): """ - Make a check that all readers have the same filetype and are of the + Make a check that all readers have the same filetype and are of the allowed files types. Throws Exception on failure. Parameters @@ -151,7 +149,7 @@ def check_allowed_filetypes(readers, allowed): readers : list of MDA readers allowed : list of allowed formats """ - classname = type(readers[0]) + classname = type(readers[0]) only_one_reader = np.all([isinstance(r, classname) for r in readers]) if not only_one_reader: readernames = [type(r) for r in readers] @@ -160,7 +158,7 @@ def check_allowed_filetypes(readers, allowed): "Found: {}".format(readernames)) if readers[0].format not in allowed: raise NotImplementedError("ChainReader: continuous=True only " - "supported for formats: {}".format(allowed)) + "supported for formats: {}".format(allowed)) class ChainReader(base.ProtoReader): @@ -265,7 +263,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): kwargs['dt'] = dt self.readers = [core.reader(filename, **kwargs) for filename in filenames] - self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn for fn in filenames]) + self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn + for fn in filenames]) # pointer to "active" trajectory index into self.readers self.__active_reader_index = 0 @@ -292,9 +291,6 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): self.dts = np.array(self._get('dt')) self.total_times = self.dts * n_frames - #: source for trajectories frame (fakes trajectory) - self.__chained_trajectories_iter = None - # calculate new start_frames to have a time continuous trajectory. if continuous: check_allowed_filetypes(self.readers, ['XTC', 'TRR']) @@ -348,7 +344,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs): # check for interleaving r1[1] if r1_start_time < start_time < r1.time: - raise RuntimeError("ChainReader: Interleaving not supported with continuous=True.") + raise RuntimeError("ChainReader: Interleaving not supported " + "with continuous=True.") # find end where trajectory was restarted from for ts in r1[::-1]: @@ -441,8 +438,6 @@ def copy(self): new.ts = self.ts.copy() return new - - # attributes that can change with the current reader @property def filename(self): @@ -563,49 +558,39 @@ def _read_frame(self, frame): # update Timestep self.ts = self.active_reader.ts self.ts.frame = frame # continuous frames, 0-based + self.__current_frame = frame return self.ts - def _chained_iterator(self): - """Iterator that presents itself as a chained trajectory.""" - self._rewind() # must rewind all readers - for i in range(self.n_frames): - j, f = self._get_local_frame(i) - self.__activate_reader(j) - self.ts = self.active_reader[f] - self.ts.frame = i - yield self.ts def _read_next_timestep(self, ts=None): - self.ts = next(self.__chained_trajectories_iter) - return self.ts + if ts is None: + ts = self.ts + ts = self.__next__() + return ts def rewind(self): """Set current frame to the beginning.""" self._rewind() - self.__chained_trajectories_iter = self._chained_iterator() - # set time step for frame 1 - self.ts = next(self.__chained_trajectories_iter) def _rewind(self): """Internal method: Rewind trajectories themselves and trj pointer.""" + self.__current_frame = -1 self._apply('rewind') - self.__activate_reader(0) + self.__next__() def close(self): self._apply('close') def __iter__(self): - """Generator for all frames, starting at frame 1.""" - self._rewind() + """Generator for all frames, starting at frame 0.""" + self.__current_frame = -1 # start from first frame - self.__chained_trajectories_iter = self._chained_iterator() - for ts in self.__chained_trajectories_iter: - yield ts + return self def __repr__(self): if len(self.filenames) > 3: fnames = "{fname} and {nfanmes} more".format( - fname=os.path.basename(self.filenames[0]), + fname=os.path.basename(self.filenames[0]), nfanmes=len(self.filenames) - 1) else: fnames = ", ".join([os.path.basename(fn) for fn in self.filenames]) @@ -658,3 +643,14 @@ def _apply_transformations(self, ts): # to avoid applying the same transformations multiple times on each frame return ts + + def __next__(self): + if self.__current_frame < self.n_frames - 1: + j, f = self._get_local_frame(self.__current_frame + 1) + self.__activate_reader(j) + self.ts = self.active_reader[f] + self.ts.frame = self.__current_frame + 1 + self.__current_frame += 1 + return self.ts + else: + raise StopIteration() diff --git a/testsuite/MDAnalysisTests/coordinates/test_chainreader.py b/testsuite/MDAnalysisTests/coordinates/test_chainreader.py index 3244ecabcde..99102c64c9e 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_chainreader.py +++ b/testsuite/MDAnalysisTests/coordinates/test_chainreader.py @@ -99,6 +99,11 @@ def test_frame_numbering(self, universe): universe.trajectory[98] # index is 0-based and frames are 0-based assert_equal(universe.trajectory.frame, 98, "wrong frame number") + def test_next_after_frame_numbering(self, universe): + universe.trajectory[98] # index is 0-based and frames are 0-based + universe.trajectory.next() + assert_equal(universe.trajectory.frame, 99, "wrong frame number") + def test_frame(self, universe): universe.trajectory[0] coord0 = universe.atoms.positions.copy()