diff --git a/package/CHANGELOG b/package/CHANGELOG index fa139c024a5..bf9e895a954 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -108,6 +108,7 @@ Fixes * Fix tests for analysis.bat that could fail when run in parallel and that would create a test artifact (Issue #2979, PR #2981) * Fix syntax warning over comparison of literals using is (Issue #3066) + * new `Results` class now can be pickled/unpickled. (PR #3309) Enhancements * Added guessers for aromaticity and Gasteiger partial charges (Issue #2468, diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py index f38a24890eb..c7bb9a40293 100644 --- a/package/MDAnalysis/analysis/base.py +++ b/package/MDAnalysis/analysis/base.py @@ -62,10 +62,6 @@ class in `scikit-learn`_. If a key is not of type ``str`` and therefore is not able to be accessed by attribute. - Notes - ----- - Pickling of ``Results`` is currently not supported. - Examples -------- >>> from MDAnalysis.analysis.base import Results @@ -103,7 +99,11 @@ def __setitem__(self, key, item): self._validate_key(key) super().__setitem__(key, item) - __setattr__ = __setitem__ + def __setattr__(self, attr, val): + if attr == 'data': + super().__setattr__(attr, val) + else: + self.__setitem__(attr, val) def __getattr__(self, attr): try: @@ -119,6 +119,12 @@ def __delattr__(self, attr): raise AttributeError("'Results' object has no " f"attribute '{attr}'") from err + def __getstate__(self): + return self.data + + def __setstate__(self, state): + self.data = state + class AnalysisBase(object): r"""Base class for defining multi-frame analysis diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py index 9c4995e4c07..8201a93e084 100644 --- a/testsuite/MDAnalysisTests/analysis/test_base.py +++ b/testsuite/MDAnalysisTests/analysis/test_base.py @@ -21,6 +21,7 @@ # J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 # from collections import UserDict +import pickle import pytest @@ -125,6 +126,10 @@ def test_update_data_fail(self, results): with pytest.raises(AttributeError, match=msg): results.update({"data": 0}) + def test_pickle(self, results): + results_p = pickle.dumps(results) + results_new = pickle.loads(results_p) + @pytest.mark.parametrize("args, kwargs, length", [ (({"darth": "tater"},), {}, 1), ([], {"darth": "tater"}, 1), diff --git a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py index 0832e907756..12e63ba31b5 100644 --- a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py +++ b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py @@ -29,6 +29,7 @@ import MDAnalysis as mda from MDAnalysis.coordinates.core import get_reader_for +from MDAnalysis.analysis.rms import RMSD from MDAnalysisTests.datafiles import ( CRD, @@ -177,3 +178,11 @@ def test_readers_pickle(ref_reader): # single frame files pass assert_equal(reanimated.ts, ref_reader.ts) + + +def test_analysis_pickle(): + u = mda.Universe(PSF, DCD) + rmsd = RMSD(u.atoms, u.atoms) + rmsd.run() + rmsd_p = pickle.dumps(rmsd) + rmsd_new = pickle.loads(rmsd_p)