Skip to content

Commit

Permalink
No need to redirect to the original frame explicitly during deseriali…
Browse files Browse the repository at this point in the history
…zation (#3722)
  • Loading branch information
yuxuanzhuang authored Jul 4, 2024
1 parent cfa4438 commit cfda8b7
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 25 deletions.
5 changes: 5 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ The rules for this file:
* 2.8.0

Fixes
* Fix failure in double-serialization of TextIOPicklable file reader.
(Issue #3723, PR #3722)
* Fix failure to preserve modification of coordinates after serialization,
e.g. with transformations
(Issue #4633, PR #3722)
* Fix PSFParser error when encountering string-like resids
(Issue #2053, Issue #4189, PR #4582)
* Fix `MDAnalysis.analysis.align.AlignTraj` not accepting writer kwargs
Expand Down
1 change: 0 additions & 1 deletion package/MDAnalysis/coordinates/H5MD.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,6 @@ def __setstate__(self, state):
self.__dict__ = state
self._particle_group = self._file['particles'][
list(self._file['particles'])[0]]
self[self.ts.frame]


class H5MDWriter(base.WriterBase):
Expand Down
12 changes: 9 additions & 3 deletions package/MDAnalysis/coordinates/TNG.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,15 @@ def __setstate__(self, state):
self.__dict__ = state
# reconstruct file iterator
self._file_iterator = pytng.TNGFileIterator(self.filename, "r")
# make sure we re-read the current frame to update C level objects in
# the file iterator
self._read_frame(self._frame)

# unlike self._read_frame(self._frame),
# the following lines update the state of the C-level file iterator
# without updating the ts object.
# This is necessary to preserve the modification,
# e.g. changing coordinates, in the ts object.
# see PR #3722 for more details.
step = self._frame_to_step(self._frame)
_ = self._file_iterator.read_step(step)

def Writer(self):
"""Writer for TNG files
Expand Down
7 changes: 3 additions & 4 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,9 @@ class ProtoReader(IOBase, metaclass=_Readermeta):
.. versionchanged:: 2.0.0
Now supports (un)pickle. Upon unpickling,
the current timestep is retained by reconstruction.
.. versionchanged:: 2.8.0
Modifications of the coordinates are now preserved
after serialization.
"""

#: The appropriate Timestep class, e.g.
Expand Down Expand Up @@ -1442,10 +1445,6 @@ def _apply_transformations(self, ts):

return ts

def __setstate__(self, state):
self.__dict__ = state
self[self.ts.frame]


class ReaderBase(ProtoReader):
"""Base class for trajectory readers that extends :class:`ProtoReader` with a
Expand Down
44 changes: 29 additions & 15 deletions package/MDAnalysis/lib/picklable_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ def __init__(self, name, mode='r'):

def __setstate__(self, state):
name = state["name_val"]
super().__init__(name, mode='r')
self.__init__(name, mode='r')
try:
self.seek(state["tell_val"])
except KeyError:
pass


def __reduce_ex__(self, prot):
if self._mode != 'r':
raise RuntimeError("Can only pickle files that were opened "
Expand Down Expand Up @@ -165,7 +166,7 @@ def __setstate__(self, state):
raw_class = state["raw_class"]
name = state["name_val"]
raw = raw_class(name)
super().__init__(raw)
self.__init__(raw)
self.seek(state["tell_val"])

def __reduce_ex__(self, prot):
Expand All @@ -177,18 +178,13 @@ def __reduce_ex__(self, prot):
"name_val": self.name,
"tell_val": self.tell()})


class TextIOPicklable(io.TextIOWrapper):
"""Character and line based picklable file-like object.
This class provides a file-like :class:`io.TextIOWrapper` object that can
be pickled. Note that this only works in read mode.
Note
----
After pickling, the current position is reset. `universe.trajectory[i]` has
to be used to return to its original frame.
Parameters
----------
raw : FileIO object
Expand All @@ -207,21 +203,34 @@ class TextIOPicklable(io.TextIOWrapper):
.. versionadded:: 2.0.0
.. versionchanged:: 2.8.0
The raw class instance instead of the class name
that is wrapped inside will be serialized.
After deserialization, the current position is no longer reset
so `universe.trajectory[i]` is not needed to seek to the
original position.
"""
def __init__(self, raw):
super().__init__(raw)
self.raw_class = raw.__class__


def __setstate__(self, args):
raw_class = args["raw_class"]
name = args["name_val"]
tell = args["tell_val"]
# raw_class is stored so that this functionality can later be extended
# to Gzip files, which also require a text wrapper.
raw = raw_class(name)
super().__init__(raw)
self.__init__(raw)
if tell is not None:
self.seek(tell)

def __reduce_ex__(self, prot):
try:
curr_loc = self.tell()
# some readers (e.g. GMS) disable tell() due to using next()
except OSError:
curr_loc = None
try:
name = self.name
except AttributeError:
Expand All @@ -230,7 +239,8 @@ def __reduce_ex__(self, prot):
return (self.__class__.__new__,
(self.__class__,),
{"raw_class": self.raw_class,
"name_val": name})
"name_val": name,
"tell_val": curr_loc})


class BZ2Picklable(bz2.BZ2File):
Expand Down Expand Up @@ -293,9 +303,11 @@ def __getstate__(self):
return {"name_val": self._fp.name, "tell_val": self.tell()}

def __setstate__(self, args):
super().__init__(args["name_val"])
name = args["name_val"]
tell = args["tell_val"]
self.__init__(name)
try:
self.seek(args["tell_val"])
self.seek(tell)
except KeyError:
pass

Expand Down Expand Up @@ -361,9 +373,11 @@ def __getstate__(self):
"tell_val": self.tell()}

def __setstate__(self, args):
super().__init__(args["name_val"])
name = args["name_val"]
tell = args["tell_val"]
self.__init__(name)
try:
self.seek(args["tell_val"])
self.seek(tell)
except KeyError:
pass

Expand Down
19 changes: 19 additions & 0 deletions testsuite/MDAnalysisTests/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,17 @@ def test_last_slice(self):
def test_pickle_singleframe_reader(self):
reader = self.universe.trajectory
reader_p = pickle.loads(pickle.dumps(reader))
reader_p_p = pickle.loads(pickle.dumps(reader_p))
assert_equal(len(reader), len(reader_p))
assert_equal(reader.ts, reader_p.ts,
"Single-frame timestep is changed after pickling")
assert_equal(len(reader), len(reader_p_p))
assert_equal(reader.ts, reader_p_p.ts,
"Single-frame timestep is changed after double pickling")
reader.ts.positions[0] = np.array([1, 2, 3])
reader_p = pickle.loads(pickle.dumps(reader))
assert_equal(reader.ts, reader_p.ts,
"Modification of ts not preserved after serialization")


class BaseReference(object):
Expand Down Expand Up @@ -443,6 +451,14 @@ def test_pickle_reader(self, reader):
assert_equal(len(reader), len(reader_p))
assert_equal(reader.ts, reader_p.ts,
"Timestep is changed after pickling")
reader_p_p = pickle.loads(pickle.dumps(reader_p))
assert_equal(len(reader), len(reader_p_p))
assert_equal(reader.ts, reader_p_p.ts,
"Timestep is changed after double pickling")
reader.ts.positions[0] = np.array([1, 2, 3])
reader_p = pickle.loads(pickle.dumps(reader))
assert_equal(reader.ts, reader_p.ts,
"Modification of ts not preserved after serialization")

def test_frame_collect_all_same(self, reader):
# check that the timestep resets so that the base reference is the same
Expand Down Expand Up @@ -604,6 +620,9 @@ def test_pickle_next_ts_reader(self, reader):
reader_p = pickle.loads(pickle.dumps(reader))
assert_equal(next(reader), next(reader_p),
"Next timestep is changed after pickling")
reader_p_p = pickle.loads(pickle.dumps(reader_p))
assert_equal(next(reader), next(reader_p_p),
"Next timestep is changed after double pickling")

# To make sure pickle works for last frame.
def test_pickle_last_ts_reader(self, reader):
Expand Down
4 changes: 2 additions & 2 deletions testsuite/MDAnalysisTests/utils/test_pickleio.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ def test_iopickle_text(f_text):
assert_equal(f_text.readlines(), f_text_pickled.readlines())


def test_offset_text_to_0(f_text):
def test_offset_text_same(f_text):
f_text.readline()
f_text_pickled = pickle.loads(pickle.dumps(f_text))
assert_equal(f_text_pickled.tell(), 0)
assert_equal(f_text_pickled.tell(), f_text.tell())


@pytest.fixture(params=[
Expand Down

0 comments on commit cfda8b7

Please sign in to comment.