Skip to content

Commit

Permalink
Add supports for pickle serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jun 19, 2024
1 parent 2e053b2 commit aa6934b
Show file tree
Hide file tree
Showing 13 changed files with 90 additions and 11 deletions.
22 changes: 22 additions & 0 deletions examples/scf/25-pickle_dumps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python

'''
Serialization for a PySCF object.
'''

# Most methods can be pickled
import pickle
import pyscf

mol = pyscf.M(atom='H 0 0 0; H 0 0 1')
mf = mol.RKS(xc='pbe')
s = pickle.dumps(mf)
mf1 = pickle.loads(s)

# Dynamically generated classes cannot be serialized by the standard pickle module.
# In this case, the third party packages cloudpickle or dill support the
# dynamical classes.
import cloudpickle
mf = mol.RHF().density_fit().x2c().newton()
s = cloudpickle.dumps(mf)
mf1 = cloudpickle.loads(s)
3 changes: 3 additions & 0 deletions pyscf/cc/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,9 @@ def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):
self._nmo = None
self.chkfile = mf.chkfile

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

@property
def ecc(self):
return self.e_corr
Expand Down
4 changes: 4 additions & 0 deletions pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __init__(self, mol, auxbasis=None):
self._vjopt = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_vjopt', '_rsh_df'),
reset_state=True)

@property
def auxbasis(self):
return self._auxbasis
Expand Down
8 changes: 4 additions & 4 deletions pyscf/gto/mole.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,12 +1244,12 @@ def unpack(moldic):
def dumps(mol):
'''Serialize Mole object to a JSON formatted str.
'''
exclude_keys = {'output', 'stdout', '_keys',
# Constructing in function loads
'symm_orb', 'irrep_id', 'irrep_name'}
exclude_keys = {'output', 'stdout', '_keys', '_ctx_lock',
# Constructing in function loads
'symm_orb', 'irrep_id', 'irrep_name'}
# FIXME: nparray and kpts for cell objects may need to be excluded
nparray_keys = {'_atm', '_bas', '_env', '_ecpbas',
'_symm_orig', '_symm_axes'}
'_symm_orig', '_symm_axes'}

moldic = dict(mol.__dict__)
for k in exclude_keys:
Expand Down
28 changes: 27 additions & 1 deletion pyscf/lib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import itertools
import inspect
import collections
import pickle
import ctypes
import numpy
import scipy
Expand Down Expand Up @@ -548,6 +549,29 @@ def view(obj, cls):
new_obj.__dict__.update(obj.__dict__)
return new_obj

def generate_pickle_methods(excludes=(), reset_state=False):
'''Generate methods for pickle, e.g.:
class A:
__getstate__, __setstate__ = generate_pickle_methods(excludes=('a', 'b', 'c'))
'''
def getstate(obj):
dic = {**obj.__dict__}
dic.pop('stdout', None)
for key in excludes:
dic.pop(key, None)
return dic

def setstate(obj, state):
obj.stdout = sys.stdout
obj.__dict__.update(state)
for key in excludes:
setattr(obj, key, None)
if reset_state and hasattr(obj, 'reset'):
obj.reset()

return getstate, setstate


SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True)
class StreamObject:
Expand Down Expand Up @@ -670,6 +694,9 @@ def copy(self):
'''Returns a shallow copy'''
return self.view(self.__class__)

__getstate__, __setstate__ = generate_pickle_methods()


_warn_once_registry = {}
def check_sanity(obj, keysref, stdout=sys.stdout):
'''Check misinput of class attributes, check whether a class method is
Expand Down Expand Up @@ -1509,4 +1536,3 @@ def to_gpu(method, out=None):
setattr(out, key, val)
out.reset()
return out

7 changes: 7 additions & 0 deletions pyscf/lib/test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def test_isintsequence(self):
def test_prange_split(self):
self.assertEqual(list(lib.prange_split(10, 3)), [(0, 4), (4, 7), (7, 10)])

def test_pickle(self):
import pickle
from pyscf import gto
mol = gto.M()
mf = mol.GKS(xc='pbe')
pickle.loads(pickle.dumps(mf))


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions pyscf/mcscf/mc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,6 @@ class CASSCF(casci.CASBase):
def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
casci.CASBase.__init__(self, mf_or_mol, ncas, nelecas, ncore)
self.frozen = frozen

self.callback = None
self.chkfile = self._scf.chkfile

self.fcisolver.max_cycle = getattr(__config__,
Expand All @@ -777,6 +775,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
self.converged = False
self._max_stepsize = None

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

def dump_flags(self, verbose=None):
log = logger.new_logger(self, verbose)
log.info('')
Expand Down
4 changes: 4 additions & 0 deletions pyscf/mcscf/umc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy
import pyscf.gto
import pyscf.scf
from pyscf import lib
from pyscf.lib import logger
from pyscf.mcscf import ucasci
from pyscf.mcscf.mc1step import expmat, rotate_orb_cc, max_stepsize_scheduler, as_scanner
Expand Down Expand Up @@ -399,6 +400,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
self.converged = False
self._max_stepsize = None

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

def dump_flags(self, verbose=None):
log = logger.new_logger(self, verbose)
log.info('')
Expand Down
3 changes: 3 additions & 0 deletions pyscf/pbc/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def __init__(self, cell, kpts=numpy.zeros((1,3))):
self._cderi = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True)

@property
def auxbasis(self):
return self._auxbasis
Expand Down
3 changes: 3 additions & 0 deletions pyscf/pbc/df/mdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, cell, kpts=np.zeros((1,3))):
self._cderi = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True)

def build(self, j_only=None, with_j3c=True, kpts_band=None):
df.GDF.build(self, j_only, with_j3c, kpts_band)
cell = self.cell
Expand Down
8 changes: 6 additions & 2 deletions pyscf/pbc/scf/rsjk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def __init__(self, cell, kpts=np.zeros((1,3))):
self._last_vs = (0, 0)
self._qindex = None

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('rs_cell', 'cell_d', 'supmol_sr', 'supmol_ft', 'supmol_d',
'_sr_without_dddd', '_last_vs', '_qindex'),
reset_state=True)

def has_long_range(self):
'''Whether to add the long-range part computed with AFT/FFT integrals'''
return self.omega is None or abs(self.cell.omega) < self.omega
Expand Down Expand Up @@ -123,8 +128,7 @@ def reset(self, cell=None):
self.supmol_sr = None
self.supmol_ft = None
self.supmol_d = None
self.exclude_dd_block = True
self.approx_vk_lr_missing_mo = False
self._sr_without_dddd = None
self._last_vs = (0, 0)
self._qindex = None
return self
Expand Down
3 changes: 1 addition & 2 deletions pyscf/scf/dhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,7 @@ class DHF(hf.SCF):
# corrections for small component when with_ssss is set to False
ssss_approx = getattr(__config__, 'scf_dhf_SCF_ssss_approx', 'Visscher')

_keys = {'conv_tol', 'with_ssss', 'with_gaunt',
'with_breit', 'ssss_approx'}
_keys = {'conv_tol', 'with_ssss', 'with_gaunt', 'with_breit', 'ssss_approx'}

def __init__(self, mol):
hf.SCF.__init__(self, mol)
Expand Down
3 changes: 3 additions & 0 deletions pyscf/scf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,9 @@ def __init__(self, mol):
self._opt = {None: None}
self._eri = None # Note: self._eri requires large amount of memory

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', '_chkfile', '_opt', '_eri', 'callback'))

def check_sanity(self):
s1e = self.get_ovlp()
cond = lib.cond(s1e)
Expand Down

0 comments on commit aa6934b

Please sign in to comment.