Skip to content

Commit

Permalink
Enable pickle serialization (pyscf#2269)
Browse files Browse the repository at this point in the history
* Improve the dump_chk method for scf and mcscf classes

* Fix bug

* Add support for pickle serialization
  • Loading branch information
sunqm authored Aug 5, 2024
1 parent ab210cd commit 1f65ec7
Show file tree
Hide file tree
Showing 15 changed files with 223 additions and 75 deletions.
22 changes: 22 additions & 0 deletions examples/scf/25-pickle_dumps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python

'''
Serialization of PySCF objects.

The first part round-trips an RKS object through the standard pickle module.
The second part round-trips an object built by chaining .density_fit(),
.x2c() and .newton() through the third-party cloudpickle package.
'''

# Most methods can be pickled with the standard pickle module.
import pickle
import pyscf

mol = pyscf.M(atom='H 0 0 0; H 0 0 1')
mf = mol.RKS(xc='pbe')
# Serialize the method object, then rebuild a new object from the result.
s = pickle.dumps(mf)
mf1 = pickle.loads(s)

# Dynamically generated classes cannot be serialized by the standard pickle
# module. In this case, the third-party packages cloudpickle or dill can be
# used, since they support dynamically generated classes.
# cloudpickle is imported here, not at the top, so the pickle-only part above
# runs even when cloudpickle is not installed.
import cloudpickle
mf = mol.RHF().density_fit().x2c().newton()
s = cloudpickle.dumps(mf)
mf1 = cloudpickle.loads(s)
3 changes: 3 additions & 0 deletions pyscf/cc/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,9 @@ def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None):
self._nmo = None
self.chkfile = mf.chkfile

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

@property
def ecc(self):
return self.e_corr
Expand Down
4 changes: 4 additions & 0 deletions pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __init__(self, mol, auxbasis=None):
self._vjopt = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_vjopt', '_rsh_df'),
reset_state=True)

@property
def auxbasis(self):
return self._auxbasis
Expand Down
8 changes: 4 additions & 4 deletions pyscf/gto/mole.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,12 +1244,12 @@ def unpack(moldic):
def dumps(mol):
'''Serialize Mole object to a JSON formatted str.
'''
exclude_keys = {'output', 'stdout', '_keys',
# Constructing in function loads
'symm_orb', 'irrep_id', 'irrep_name'}
exclude_keys = {'output', 'stdout', '_keys', '_ctx_lock',
# Constructing in function loads
'symm_orb', 'irrep_id', 'irrep_name'}
# FIXME: nparray and kpts for cell objects may need to be excluded
nparray_keys = {'_atm', '_bas', '_env', '_ecpbas',
'_symm_orig', '_symm_axes'}
'_symm_orig', '_symm_axes'}

moldic = dict(mol.__dict__)
for k in exclude_keys:
Expand Down
28 changes: 27 additions & 1 deletion pyscf/lib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import itertools
import inspect
import collections
import pickle
import weakref
import ctypes
import numpy
Expand Down Expand Up @@ -549,6 +550,29 @@ def view(obj, cls):
new_obj.__dict__.update(obj.__dict__)
return new_obj

def generate_pickle_methods(excludes=(), reset_state=False):
    '''Generate methods for pickle, e.g.:

    class A:
        __getstate__, __setstate__ = generate_pickle_methods(excludes=('a', 'b', 'c'))

    Args:
        excludes : str or iterable of str
            Names of the attributes that are left out of the pickled state.
            A single name may be passed as a plain string. When the object is
            restored, every excluded attribute is set to None.
        reset_state : bool
            If True and the restored object has a ``reset`` attribute,
            ``obj.reset()`` is called at the end of ``__setstate__``.

    Returns:
        A tuple ``(getstate, setstate)`` of functions, to be bound to
        ``__getstate__`` and ``__setstate__`` in a class body.
    '''
    # Freeze the names once. A bare string would otherwise be iterated
    # character by character, and a one-shot iterable (e.g. a generator)
    # would be exhausted after the first getstate/setstate call.
    if isinstance(excludes, str):
        excluded_keys = (excludes,)
    else:
        excluded_keys = tuple(excludes)

    def getstate(obj):
        # Work on a shallow copy so that the live object keeps its attributes.
        dic = {**obj.__dict__}
        # stdout is never part of the state; setstate rebinds it to sys.stdout.
        dic.pop('stdout', None)
        for key in excluded_keys:
            dic.pop(key, None)
        return dic

    def setstate(obj, state):
        obj.stdout = sys.stdout
        obj.__dict__.update(state)
        # Excluded attributes are not in the state; define them as None so
        # that they exist on the restored object.
        for key in excluded_keys:
            setattr(obj, key, None)
        if reset_state and hasattr(obj, 'reset'):
            obj.reset()

    return getstate, setstate


SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True)
class StreamObject:
Expand Down Expand Up @@ -671,6 +695,9 @@ def copy(self):
'''Returns a shallow copy'''
return self.view(self.__class__)

__getstate__, __setstate__ = generate_pickle_methods()


_warn_once_registry = {}
def check_sanity(obj, keysref, stdout=sys.stdout):
'''Check misinput of class attributes, check whether a class method is
Expand Down Expand Up @@ -1516,4 +1543,3 @@ def to_gpu(method, out=None):
setattr(out, key, val)
out.reset()
return out

7 changes: 7 additions & 0 deletions pyscf/lib/test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def test_isintsequence(self):
def test_prange_split(self):
self.assertEqual(list(lib.prange_split(10, 3)), [(0, 4), (4, 7), (7, 10)])

def test_pickle(self):
    '''A GKS object built on a Mole created with no arguments must go
    through a pickle dumps/loads round trip without raising.'''
    import pickle
    from pyscf import gto
    mf = gto.M().GKS(xc='pbe')
    payload = pickle.dumps(mf)
    pickle.loads(payload)


if __name__ == "__main__":
unittest.main()
80 changes: 49 additions & 31 deletions pyscf/mcscf/mc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None,
(max_offdiag_u < casscf.small_rot_tol or casscf.small_rot_tol == 0)):
conv = True

if dump_chk:
if dump_chk and casscf.chkfile:
casscf.dump_chk(locals())

if callable(callback):
Expand Down Expand Up @@ -499,7 +499,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None,
'call to mc.cas_natorb_() is required')
mo_energy = None

if dump_chk:
if dump_chk and casscf.chkfile:
casscf.dump_chk(locals())

log.timer('1-step CASSCF', *cput0)
Expand Down Expand Up @@ -761,8 +761,6 @@ class CASSCF(casci.CASBase):
def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
casci.CASBase.__init__(self, mf_or_mol, ncas, nelecas, ncore)
self.frozen = frozen

self.callback = None
self.chkfile = self._scf.chkfile

self.fcisolver.max_cycle = getattr(__config__,
Expand All @@ -780,6 +778,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
self.converged = False
self._max_stepsize = None

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

def dump_flags(self, verbose=None):
log = logger.new_logger(self, verbose)
log.info('')
Expand Down Expand Up @@ -1177,38 +1178,55 @@ def _exact_paaa(self, mo, u, out=None):
paaa = lib.transpose(buf.reshape(ncas*ncas,-1), out=out)
return paaa.reshape(nmo,ncas,ncas,ncas)

def dump_chk(self, envs):
if not self.chkfile:
return self
def dump_chk(self, envs_or_file):
'''Serialize the MCSCF object and save it to the specified chkfile.
if getattr(self.fcisolver, 'nevpt_intermediate', None):
civec = None
elif self.chk_ci:
civec = envs['fcivec']
Args:
envs_or_file:
If this argument is a file path, the serialized MCSCF object is
saved to the file specified by this argument.
If this argument is a dict (created by locals()), the necessary
variables are saved to the file specified by the attribute .chkfile.
'''
if isinstance(envs_or_file, str):
envs = None
chk_file = envs_or_file
else:
civec = None
envs = envs_or_file
chk_file = self.chkfile
if not chk_file:
return self

e_tot = mo_coeff = mo_occ = mo_energy = e_cas = civec = casdm1 = None
ncore = self.ncore
nocc = ncore + self.ncas
if 'mo' in envs:
mo_coeff = envs['mo']
else:
mo_coeff = envs['mo_coeff']
mo_occ = numpy.zeros(mo_coeff.shape[1])
mo_occ[:ncore] = 2
if self.natorb:
occ = self._eig(-envs['casdm1'], ncore, nocc)[0]
mo_occ[ncore:nocc] = -occ
else:
mo_occ[ncore:nocc] = envs['casdm1'].diagonal()
# Note: mo_energy in active space =/= F_{ii} (F is general Fock)
if 'mo_energy' in envs:
mo_energy = envs['mo_energy']
else:
mo_energy = 'None'
chkfile.dump_mcscf(self, self.chkfile, 'mcscf', envs['e_tot'],

if envs is not None:
if self.chk_ci:
civec = envs.get('fcivec', None)

e_tot = envs['e_tot']
e_cas = envs['e_cas']
casdm1 = envs['casdm1']
if 'mo' in envs:
mo_coeff = envs['mo']
else:
mo_coeff = envs['mo_coeff']
mo_occ = numpy.zeros(mo_coeff.shape[1])
mo_occ[:ncore] = 2
if self.natorb:
occ = self._eig(-casdm1, ncore, nocc)[0]
mo_occ[ncore:nocc] = -occ
else:
mo_occ[ncore:nocc] = casdm1.diagonal()
# Note: mo_energy in active space =/= F_{ii} (F is general Fock)
if 'mo_energy' in envs:
mo_energy = envs['mo_energy']

chkfile.dump_mcscf(self, chk_file, 'mcscf', e_tot,
mo_coeff, ncore, self.ncas, mo_occ,
mo_energy, envs['e_cas'], civec, envs['casdm1'],
overwrite_mol=False)
mo_energy, e_cas, civec, casdm1,
overwrite_mol=(envs is None))
return self

def update_from_chk(self, chkfile=None):
Expand Down
69 changes: 44 additions & 25 deletions pyscf/mcscf/umc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy
import pyscf.gto
import pyscf.scf
from pyscf import lib
from pyscf.lib import logger
from pyscf.mcscf import ucasci
from pyscf.mcscf.mc1step import expmat, rotate_orb_cc, max_stepsize_scheduler, as_scanner
Expand Down Expand Up @@ -399,6 +400,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None):
self.converged = False
self._max_stepsize = None

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('chkfile', 'callback'))

def dump_flags(self, verbose=None):
log = logger.new_logger(self, verbose)
log.info('')
Expand Down Expand Up @@ -732,39 +736,54 @@ def solve_approx_ci(self, h1, h2, ci0, ecore, e_cas):
ci1 += xs[i] * v[i,0]
return ci1, g

def dump_chk(self, envs_or_file):
    '''Serialize the MCSCF object and save it to the specified chkfile.

    Args:
        envs_or_file:
            If this argument is a file path, the serialized MCSCF object is
            saved to the file specified by this argument.
            If this argument is a dict (created by locals()), the necessary
            variables are saved to the file specified by the attribute .chkfile.

    Returns:
        self
    '''
    if isinstance(envs_or_file, str):
        envs = None
        chk_file = envs_or_file
    else:
        envs = envs_or_file
        chk_file = self.chkfile
        if not chk_file:
            return self

    e_tot = mo_coeff = mo_occ = mo_energy = e_cas = civec = casdm1 = None
    ncore = self.ncore
    ncas = self.ncas
    nocca = ncore[0] + ncas
    noccb = ncore[1] + ncas

    if envs is not None:
        if self.chk_ci:
            civec = envs['fcivec']

        # These must be read from envs before they are used below; leaving
        # them as None makes casdm1[0] fail and stores no energies.
        e_tot = envs['e_tot']
        e_cas = envs['e_cas']
        casdm1 = envs['casdm1']
        if 'mo' in envs:
            mo_coeff = envs['mo']
        else:
            mo_coeff = envs['mo_coeff']
        # Size the occupations from the coefficients actually selected above,
        # so that an envs holding only 'mo_coeff' works as well.
        mo_occ = numpy.zeros((2, mo_coeff[0].shape[1]))
        mo_occ[0,:ncore[0]] = 1
        mo_occ[1,:ncore[1]] = 1
        if self.natorb:
            occa, ucas = self._eig(-casdm1[0], ncore[0], nocca)
            occb, ucas = self._eig(-casdm1[1], ncore[1], noccb)
            mo_occ[0,ncore[0]:nocca] = -occa
            mo_occ[1,ncore[1]:noccb] = -occb
        else:
            mo_occ[0,ncore[0]:nocca] = casdm1[0].diagonal()
            mo_occ[1,ncore[1]:noccb] = casdm1[1].diagonal()

    # Write to the resolved target (explicit path or self.chkfile), not
    # unconditionally to self.chkfile.
    chkfile.dump_mcscf(self, chk_file, 'mcscf', e_tot,
                       mo_coeff, ncore, ncas, mo_occ,
                       mo_energy, e_cas, civec, casdm1,
                       overwrite_mol=(envs is None))
    return self

def rotate_mo(self, mo, u, log=None):
Expand Down
3 changes: 3 additions & 0 deletions pyscf/pbc/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def __init__(self, cell, kpts=numpy.zeros((1,3))):
self._cderi = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True)

@property
def auxbasis(self):
return self._auxbasis
Expand Down
3 changes: 3 additions & 0 deletions pyscf/pbc/df/mdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, cell, kpts=np.zeros((1,3))):
self._cderi = None
self._rsh_df = {} # Range separated Coulomb DF objects

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True)

def build(self, j_only=None, with_j3c=True, kpts_band=None):
df.GDF.build(self, j_only, with_j3c, kpts_band)
cell = self.cell
Expand Down
18 changes: 15 additions & 3 deletions pyscf/pbc/scf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,21 @@ def init_guess_by_chkfile(self, chk=None, project=None, kpt=None):
def from_chk(self, chk=None, project=None, kpt=None):
return self.init_guess_by_chkfile(chk, project, kpt)

def dump_chk(self, envs):
if self.chkfile:
mol_hf.SCF.dump_chk(self, envs)
def dump_chk(self, envs_or_file):
'''Serialize the SCF object and save it to the specified chkfile.
Args:
envs_or_file:
If this argument is a file path, the serialized SCF object is
saved to the file specified by this argument.
If this argument is a dict (created by locals()), the necessary
variables are saved to the file specified by the attribute mf.chkfile.
'''
mol_hf.SCF.dump_chk(self, envs_or_file)
if isinstance(envs_or_file, str):
with lib.H5FileWrap(envs_or_file, 'a') as fh5:
fh5['scf/kpt'] = self.kpt
elif self.chkfile:
with lib.H5FileWrap(self.chkfile, 'a') as fh5:
fh5['scf/kpt'] = self.kpt
return self
Expand Down
Loading

0 comments on commit 1f65ec7

Please sign in to comment.