Skip to content

Commit

Permalink
ChainReader Refactoring (#2815)
Browse files Browse the repository at this point in the history
* Fixes #2814
* changes in ChainReader:
   * new `_read_frame` and `__iter__` are now properly connected
      (previously, they moved independently, which could lead to the ChainReader
      being in an inconsistent state)
   * remove `_chained_iterator`.
   * add `__next__`
   * add an extra attribute `__current_frame` to internally track the absolute current frame
* add tests
* update CHANGELOG
  • Loading branch information
yuxuanzhuang authored Jul 3, 2020
1 parent 61e236d commit e878e5d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 34 deletions.
5 changes: 4 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ The rules for this file:

------------------------------------------------------------------------------
??/??/?? richardjgowers, IAlibay, hmacdope, orbeckst, cbouy, lilyminium,
daveminh, jbarnoud
daveminh, jbarnoud, yuxuanzhuang



* 2.0.0
Expand All @@ -25,6 +26,8 @@ Fixes
* TOPParser no longer guesses elements when missing atomic number records
(Issues #2449, #2651)
* Testsuite no longer calls matplotlib.use('agg') (#2191)
* In ChainReader, read_frame does not trigger change of iterating position.
(Issue #2723, PR #2815)

Enhancements
* Added the RDKitParser which creates a `core.topology.Topology` object from
Expand Down
62 changes: 29 additions & 33 deletions package/MDAnalysis/coordinates/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@
.. automethod:: _get
.. automethod:: _get_same
.. automethod:: _read_frame
.. automethod:: _chained_iterator
"""
import warnings

import os.path
import itertools
import bisect
import copy

Expand Down Expand Up @@ -141,15 +139,15 @@ def filter_times(times, dt):

def check_allowed_filetypes(readers, allowed):
"""
Make a check that all readers have the same filetype and are of the
Make a check that all readers have the same filetype and are of the
allowed files types. Throws Exception on failure.
Parameters
----------
readers : list of MDA readers
allowed : list of allowed formats
"""
classname = type(readers[0])
classname = type(readers[0])
only_one_reader = np.all([isinstance(r, classname) for r in readers])
if not only_one_reader:
readernames = [type(r) for r in readers]
Expand All @@ -158,7 +156,7 @@ def check_allowed_filetypes(readers, allowed):
"Found: {}".format(readernames))
if readers[0].format not in allowed:
raise NotImplementedError("ChainReader: continuous=True only "
"supported for formats: {}".format(allowed))
"supported for formats: {}".format(allowed))


class ChainReader(base.ProtoReader):
Expand Down Expand Up @@ -263,7 +261,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
kwargs['dt'] = dt
self.readers = [core.reader(filename, **kwargs)
for filename in filenames]
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn for fn in filenames])
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn
for fn in filenames])
# pointer to "active" trajectory index into self.readers
self.__active_reader_index = 0

Expand All @@ -290,9 +289,6 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
self.dts = np.array(self._get('dt'))
self.total_times = self.dts * n_frames

#: source for trajectories frame (fakes trajectory)
self.__chained_trajectories_iter = None

# calculate new start_frames to have a time continuous trajectory.
if continuous:
check_allowed_filetypes(self.readers, ['XTC', 'TRR'])
Expand Down Expand Up @@ -346,7 +342,8 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
# check for interleaving
r1[1]
if r1_start_time < start_time < r1.time:
raise RuntimeError("ChainReader: Interleaving not supported with continuous=True.")
raise RuntimeError("ChainReader: Interleaving not supported "
"with continuous=True.")

# find end where trajectory was restarted from
for ts in r1[::-1]:
Expand Down Expand Up @@ -439,8 +436,6 @@ def copy(self):
new.ts = self.ts.copy()
return new



# attributes that can change with the current reader
@property
def filename(self):
Expand Down Expand Up @@ -561,49 +556,39 @@ def _read_frame(self, frame):
# update Timestep
self.ts = self.active_reader.ts
self.ts.frame = frame # continuous frames, 0-based
self.__current_frame = frame
return self.ts

def _chained_iterator(self):
"""Iterator that presents itself as a chained trajectory."""
self._rewind() # must rewind all readers
for i in range(self.n_frames):
j, f = self._get_local_frame(i)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = i
yield self.ts

def _read_next_timestep(self, ts=None):
self.ts = next(self.__chained_trajectories_iter)
return self.ts
if ts is None:
ts = self.ts
ts = self.__next__()
return ts

def rewind(self):
"""Set current frame to the beginning."""
self._rewind()
self.__chained_trajectories_iter = self._chained_iterator()
# set time step for frame 1
self.ts = next(self.__chained_trajectories_iter)

def _rewind(self):
"""Internal method: Rewind trajectories themselves and trj pointer."""
self.__current_frame = -1
self._apply('rewind')
self.__activate_reader(0)
self.__next__()

def close(self):
    """Close the chained trajectory.

    Delegates to ``self._apply('close')``; presumably this forwards the
    ``close`` call to every underlying reader -- confirm against
    ``_apply``.
    """
    self._apply('close')

def __iter__(self):
"""Generator for all frames, starting at frame 1."""
self._rewind()
"""Generator for all frames, starting at frame 0."""
self.__current_frame = -1
# start from first frame
self.__chained_trajectories_iter = self._chained_iterator()
for ts in self.__chained_trajectories_iter:
yield ts
return self

def __repr__(self):
if len(self.filenames) > 3:
fnames = "{fname} and {nfanmes} more".format(
fname=os.path.basename(self.filenames[0]),
fname=os.path.basename(self.filenames[0]),
nfanmes=len(self.filenames) - 1)
else:
fnames = ", ".join([os.path.basename(fn) for fn in self.filenames])
Expand Down Expand Up @@ -656,3 +641,14 @@ def _apply_transformations(self, ts):
# to avoid applying the same transformations multiple times on each frame

return ts

def __next__(self):
    """Advance to the next frame of the chained trajectory.

    The frame following the internally tracked current frame is mapped
    to a (reader index, local frame) pair, that reader is activated, the
    frame is loaded from it and relabelled with its position in the
    chained trajectory.

    Returns
    -------
    The updated ``self.ts``, i.e. the timestep obtained from the active
    reader with ``frame`` set to the chained frame number.

    Raises
    ------
    StopIteration
        If the current frame is already the last frame.
    """
    upcoming = self.__current_frame + 1
    # guard clause: nothing left to read past the final frame
    if upcoming >= self.n_frames:
        raise StopIteration()

    reader_index, local_frame = self._get_local_frame(upcoming)
    self.__activate_reader(reader_index)
    self.ts = self.active_reader[local_frame]
    # sub-readers number their own frames; expose the chained frame number
    self.ts.frame = upcoming
    self.__current_frame = upcoming
    return self.ts
5 changes: 5 additions & 0 deletions testsuite/MDAnalysisTests/coordinates/test_chainreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def test_frame_numbering(self, universe):
universe.trajectory[98] # index is 0-based and frames are 0-based
assert_equal(universe.trajectory.frame, 98, "wrong frame number")

def test_next_after_frame_numbering(self, universe):
    """next() after random access continues from the accessed frame."""
    trajectory = universe.trajectory
    trajectory[98]  # index is 0-based and frames are 0-based
    trajectory.next()
    assert_equal(trajectory.frame, 99, "wrong frame number")

def test_frame(self, universe):
universe.trajectory[0]
coord0 = universe.atoms.positions.copy()
Expand Down

0 comments on commit e878e5d

Please sign in to comment.