diff --git a/package/CHANGELOG b/package/CHANGELOG index 876a0541de2..59afa87003b 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -22,6 +22,11 @@ The rules for this file: * 2.8.0 Fixes + * Fix failure in double-serialization of TextIOPicklable file reader. + (Issue #3723, PR #3722) + * Fix failure to preserve modification of coordinates after serialization, + e.g. with transformations + (Issue #4633, PR #3722) * Fix PSFParser error when encoutering string-like resids * (Issue #2053, Issue #4189 PR #4582) * Fix `MDAnalysis.analysis.align.AlignTraj` not accepting writer kwargs diff --git a/package/MDAnalysis/coordinates/H5MD.py b/package/MDAnalysis/coordinates/H5MD.py index b8264f0339e..1062240b48c 100644 --- a/package/MDAnalysis/coordinates/H5MD.py +++ b/package/MDAnalysis/coordinates/H5MD.py @@ -828,7 +828,6 @@ def __setstate__(self, state): self.__dict__ = state self._particle_group = self._file['particles'][ list(self._file['particles'])[0]] - self[self.ts.frame] class H5MDWriter(base.WriterBase): diff --git a/package/MDAnalysis/coordinates/TNG.py b/package/MDAnalysis/coordinates/TNG.py index 3a037a5537f..7a44be3518b 100644 --- a/package/MDAnalysis/coordinates/TNG.py +++ b/package/MDAnalysis/coordinates/TNG.py @@ -499,9 +499,15 @@ def __setstate__(self, state): self.__dict__ = state # reconstruct file iterator self._file_iterator = pytng.TNGFileIterator(self.filename, "r") - # make sure we re-read the current frame to update C level objects in - # the file iterator - self._read_frame(self._frame) + + # unlike self._read_frame(self._frame), + # the following lines update the state of the C-level file iterator + # without updating the ts object. + # This is necessary to preserve the modification, + # e.g. changing coordinates, in the ts object. + # see PR #3722 for more details. 
+ step = self._frame_to_step(self._frame) + _ = self._file_iterator.read_step(step) def Writer(self): """Writer for TNG files diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 45befe1c552..dda4a61a7ce 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -659,6 +659,9 @@ class ProtoReader(IOBase, metaclass=_Readermeta): .. versionchanged:: 2.0.0 Now supports (un)pickle. Upon unpickling, the current timestep is retained by reconstrunction. + .. versionchanged:: 2.8.0 + Modifications of coordinates are now preserved + after serialization. """ #: The appropriate Timestep class, e.g. @@ -1442,10 +1445,6 @@ def _apply_transformations(self, ts): return ts - def __setstate__(self, state): - self.__dict__ = state - self[self.ts.frame] - class ReaderBase(ProtoReader): """Base class for trajectory readers that extends :class:`ProtoReader` with a diff --git a/package/MDAnalysis/lib/picklable_file_io.py b/package/MDAnalysis/lib/picklable_file_io.py index 2425e9d458a..91413619c5b 100644 --- a/package/MDAnalysis/lib/picklable_file_io.py +++ b/package/MDAnalysis/lib/picklable_file_io.py @@ -114,12 +114,13 @@ def __init__(self, name, mode='r'): def __setstate__(self, state): name = state["name_val"] - super().__init__(name, mode='r') + self.__init__(name, mode='r') try: self.seek(state["tell_val"]) except KeyError: pass + def __reduce_ex__(self, prot): if self._mode != 'r': raise RuntimeError("Can only pickle files that were opened " @@ -165,7 +166,7 @@ def __setstate__(self, state): raw_class = state["raw_class"] name = state["name_val"] raw = raw_class(name) - super().__init__(raw) + self.__init__(raw) self.seek(state["tell_val"]) def __reduce_ex__(self, prot): @@ -177,18 +178,13 @@ def __reduce_ex__(self, prot): "name_val": self.name, "tell_val": self.tell()}) + class TextIOPicklable(io.TextIOWrapper): """Character and line based picklable file-like object. 
This class provides a file-like :class:`io.TextIOWrapper` object that can be pickled. Note that this only works in read mode. - Note - ---- - After pickling, the current position is reset. `universe.trajectory[i]` has - to be used to return to its original frame. - - Parameters ---------- raw : FileIO object @@ -207,21 +203,34 @@ class TextIOPicklable(io.TextIOWrapper): .. versionadded:: 2.0.0 + .. versionchanged:: 2.8.0 + The raw class instance instead of the class name + that is wrapped inside will be serialized. + After deserialization, the current position is no longer reset + so `universe.trajectory[i]` is not needed to seek to the + original position. """ def __init__(self, raw): super().__init__(raw) self.raw_class = raw.__class__ - def __setstate__(self, args): raw_class = args["raw_class"] name = args["name_val"] + tell = args["tell_val"] # raw_class is used for further expansion this functionality to # Gzip files, which also requires a text wrapper. raw = raw_class(name) - super().__init__(raw) + self.__init__(raw) + if tell is not None: + self.seek(tell) def __reduce_ex__(self, prot): + try: + curr_loc = self.tell() + # some readers (e.g. 
GMS) disable tell() due to using next() + except OSError: + curr_loc = None try: name = self.name except AttributeError: @@ -230,7 +239,8 @@ def __reduce_ex__(self, prot): return (self.__class__.__new__, (self.__class__,), {"raw_class": self.raw_class, - "name_val": name}) + "name_val": name, + "tell_val": curr_loc}) class BZ2Picklable(bz2.BZ2File): @@ -293,9 +303,11 @@ def __getstate__(self): return {"name_val": self._fp.name, "tell_val": self.tell()} def __setstate__(self, args): - super().__init__(args["name_val"]) + name = args["name_val"] + tell = args["tell_val"] + self.__init__(name) try: - self.seek(args["tell_val"]) + self.seek(tell) except KeyError: pass @@ -361,9 +373,11 @@ def __getstate__(self): "tell_val": self.tell()} def __setstate__(self, args): - super().__init__(args["name_val"]) + name = args["name_val"] + tell = args["tell_val"] + self.__init__(name) try: - self.seek(args["tell_val"]) + self.seek(tell) except KeyError: pass diff --git a/testsuite/MDAnalysisTests/coordinates/base.py b/testsuite/MDAnalysisTests/coordinates/base.py index a2a131bbbde..3de8cfb9ff6 100644 --- a/testsuite/MDAnalysisTests/coordinates/base.py +++ b/testsuite/MDAnalysisTests/coordinates/base.py @@ -122,9 +122,17 @@ def test_last_slice(self): def test_pickle_singleframe_reader(self): reader = self.universe.trajectory reader_p = pickle.loads(pickle.dumps(reader)) + reader_p_p = pickle.loads(pickle.dumps(reader_p)) assert_equal(len(reader), len(reader_p)) assert_equal(reader.ts, reader_p.ts, "Single-frame timestep is changed after pickling") + assert_equal(len(reader), len(reader_p_p)) + assert_equal(reader.ts, reader_p_p.ts, + "Single-frame timestep is changed after double pickling") + reader.ts.positions[0] = np.array([1, 2, 3]) + reader_p = pickle.loads(pickle.dumps(reader)) + assert_equal(reader.ts, reader_p.ts, + "Modification of ts not preserved after serialization") class BaseReference(object): @@ -443,6 +451,14 @@ def test_pickle_reader(self, reader): 
assert_equal(len(reader), len(reader_p)) assert_equal(reader.ts, reader_p.ts, "Timestep is changed after pickling") + reader_p_p = pickle.loads(pickle.dumps(reader_p)) + assert_equal(len(reader), len(reader_p_p)) + assert_equal(reader.ts, reader_p_p.ts, + "Timestep is changed after double pickling") + reader.ts.positions[0] = np.array([1, 2, 3]) + reader_p = pickle.loads(pickle.dumps(reader)) + assert_equal(reader.ts, reader_p.ts, + "Modification of ts not preserved after serialization") def test_frame_collect_all_same(self, reader): # check that the timestep resets so that the base reference is the same @@ -604,6 +620,9 @@ def test_pickle_next_ts_reader(self, reader): reader_p = pickle.loads(pickle.dumps(reader)) assert_equal(next(reader), next(reader_p), "Next timestep is changed after pickling") + reader_p_p = pickle.loads(pickle.dumps(reader_p)) + assert_equal(next(reader), next(reader_p_p), + "Next timestep is changed after double pickling") # To make sure pickle works for last frame. def test_pickle_last_ts_reader(self, reader): diff --git a/testsuite/MDAnalysisTests/utils/test_pickleio.py b/testsuite/MDAnalysisTests/utils/test_pickleio.py index d26de0aa258..824261ed218 100644 --- a/testsuite/MDAnalysisTests/utils/test_pickleio.py +++ b/testsuite/MDAnalysisTests/utils/test_pickleio.py @@ -90,10 +90,10 @@ def test_iopickle_text(f_text): assert_equal(f_text.readlines(), f_text_pickled.readlines()) -def test_offset_text_to_0(f_text): +def test_offset_text_same(f_text): f_text.readline() f_text_pickled = pickle.loads(pickle.dumps(f_text)) - assert_equal(f_text_pickled.tell(), 0) + assert_equal(f_text_pickled.tell(), f_text.tell()) @pytest.fixture(params=[