From 1f65ec7a6df708aeaf1823e620ae770cdac5f9b6 Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Mon, 5 Aug 2024 08:07:01 -0700 Subject: [PATCH] Enable pickle serialization (#2269) * Improve the dump_chk method for scf and mcscf classes * Fix bug * Add supports for pickle serialization --- examples/scf/25-pickle_dumps.py | 22 +++++++++ pyscf/cc/ccsd.py | 3 ++ pyscf/df/df.py | 4 ++ pyscf/gto/mole.py | 8 ++-- pyscf/lib/misc.py | 28 +++++++++++- pyscf/lib/test/test_misc.py | 7 +++ pyscf/mcscf/mc1step.py | 80 ++++++++++++++++++++------------- pyscf/mcscf/umc1step.py | 69 +++++++++++++++++----------- pyscf/pbc/df/df.py | 3 ++ pyscf/pbc/df/mdf.py | 3 ++ pyscf/pbc/scf/hf.py | 18 ++++++-- pyscf/pbc/scf/khf.py | 18 ++++++-- pyscf/pbc/scf/rsjk.py | 8 +++- pyscf/scf/dhf.py | 3 +- pyscf/scf/hf.py | 24 ++++++++-- 15 files changed, 223 insertions(+), 75 deletions(-) create mode 100644 examples/scf/25-pickle_dumps.py diff --git a/examples/scf/25-pickle_dumps.py b/examples/scf/25-pickle_dumps.py new file mode 100644 index 0000000000..4aa517afde --- /dev/null +++ b/examples/scf/25-pickle_dumps.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +''' +Serialization for a PySCF object. +''' + +# Most methods can be pickled +import pickle +import pyscf + +mol = pyscf.M(atom='H 0 0 0; H 0 0 1') +mf = mol.RKS(xc='pbe') +s = pickle.dumps(mf) +mf1 = pickle.loads(s) + +# Dynamically generated classes cannot be serialized by the standard pickle module. +# In this case, the third party packages cloudpickle or dill support the +# dynamical classes. +import cloudpickle +mf = mol.RHF().density_fit().x2c().newton() +s = cloudpickle.dumps(mf) +mf1 = cloudpickle.loads(s) diff --git a/pyscf/cc/ccsd.py b/pyscf/cc/ccsd.py index bbbba59221..0468e269a7 100644 --- a/pyscf/cc/ccsd.py +++ b/pyscf/cc/ccsd.py @@ -971,6 +971,9 @@ def __init__(self, mf, frozen=None, mo_coeff=None, mo_occ=None): self._nmo = None self.chkfile = mf.chkfile + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('chkfile', 'callback')) + @property def ecc(self): return self.e_corr diff --git a/pyscf/df/df.py b/pyscf/df/df.py index 4eb735c41f..3aea0ed81b 100644 --- a/pyscf/df/df.py +++ b/pyscf/df/df.py @@ -98,6 +98,10 @@ def __init__(self, mol, auxbasis=None): self._vjopt = None self._rsh_df = {} # Range separated Coulomb DF objects + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('_cderi_to_save', '_cderi', '_vjopt', '_rsh_df'), + reset_state=True) + @property def auxbasis(self): return self._auxbasis diff --git a/pyscf/gto/mole.py b/pyscf/gto/mole.py index d1d85c7fee..0f2ec506f9 100644 --- a/pyscf/gto/mole.py +++ b/pyscf/gto/mole.py @@ -1244,12 +1244,12 @@ def unpack(moldic): def dumps(mol): '''Serialize Mole object to a JSON formatted str. ''' - exclude_keys = {'output', 'stdout', '_keys', - # Constructing in function loads - 'symm_orb', 'irrep_id', 'irrep_name'} + exclude_keys = {'output', 'stdout', '_keys', '_ctx_lock', + # Constructing in function loads + 'symm_orb', 'irrep_id', 'irrep_name'} # FIXME: nparray and kpts for cell objects may need to be excluded nparray_keys = {'_atm', '_bas', '_env', '_ecpbas', - '_symm_orig', '_symm_axes'} + '_symm_orig', '_symm_axes'} moldic = dict(mol.__dict__) for k in exclude_keys: diff --git a/pyscf/lib/misc.py b/pyscf/lib/misc.py index 679a2b169b..77825272b0 100644 --- a/pyscf/lib/misc.py +++ b/pyscf/lib/misc.py @@ -31,6 +31,7 @@ import itertools import inspect import collections +import pickle import weakref import ctypes import numpy @@ -549,6 +550,29 @@ def view(obj, cls): new_obj.__dict__.update(obj.__dict__) return new_obj +def generate_pickle_methods(excludes=(), reset_state=False): + '''Generate methods for pickle, e.g.: + + class A: + __getstate__, __setstate__ = generate_pickle_methods(excludes=('a', 'b', 'c')) + ''' + def getstate(obj): + dic = {**obj.__dict__} + dic.pop('stdout', None) + for key in excludes: + dic.pop(key, None) + return dic + + def setstate(obj, state): + obj.stdout = sys.stdout + obj.__dict__.update(state) + for key in excludes: + setattr(obj, key, None) + if reset_state and hasattr(obj, 'reset'): + obj.reset() + + return getstate, setstate + SANITY_CHECK = getattr(__config__, 'SANITY_CHECK', True) class StreamObject: @@ -671,6 +695,9 @@ def copy(self): '''Returns a shallow copy''' return self.view(self.__class__) + __getstate__, __setstate__ = generate_pickle_methods() + + _warn_once_registry = {} def check_sanity(obj, keysref, stdout=sys.stdout): '''Check misinput of class attributes, check whether a class method is @@ -1516,4 +1543,3 @@ def to_gpu(method, out=None): setattr(out, key, val) out.reset() return out - diff --git a/pyscf/lib/test/test_misc.py b/pyscf/lib/test/test_misc.py index 21a4a5cbf7..23030466bb 100644 --- a/pyscf/lib/test/test_misc.py +++ b/pyscf/lib/test/test_misc.py @@ -87,6 +87,13 @@ def test_isintsequence(self): def test_prange_split(self): self.assertEqual(list(lib.prange_split(10, 3)), [(0, 4), (4, 7), (7, 10)]) + def test_pickle(self): + import pickle + from pyscf import gto + mol = gto.M() + mf = mol.GKS(xc='pbe') + pickle.loads(pickle.dumps(mf)) + if __name__ == "__main__": unittest.main() diff --git a/pyscf/mcscf/mc1step.py b/pyscf/mcscf/mc1step.py index 2525defa48..3ea8199fe4 100644 --- a/pyscf/mcscf/mc1step.py +++ b/pyscf/mcscf/mc1step.py @@ -467,7 +467,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None, (max_offdiag_u < casscf.small_rot_tol or casscf.small_rot_tol == 0)): conv = True - if dump_chk: + if dump_chk and casscf.chkfile: casscf.dump_chk(locals()) if callable(callback): @@ -499,7 +499,7 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None, 'call to mc.cas_natorb_() is required') mo_energy = None - if dump_chk: + if dump_chk and casscf.chkfile: casscf.dump_chk(locals()) log.timer('1-step CASSCF', *cput0) @@ -761,8 +761,6 @@ class CASSCF(casci.CASBase): def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None): casci.CASBase.__init__(self, mf_or_mol, ncas, nelecas, ncore) self.frozen = frozen - - self.callback = None self.chkfile = self._scf.chkfile self.fcisolver.max_cycle = getattr(__config__, @@ -780,6 +778,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None): self.converged = False self._max_stepsize = None + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('chkfile', 'callback')) + def dump_flags(self, verbose=None): log = logger.new_logger(self, verbose) log.info('') @@ -1177,38 +1178,55 @@ def _exact_paaa(self, mo, u, out=None): paaa = lib.transpose(buf.reshape(ncas*ncas,-1), out=out) return paaa.reshape(nmo,ncas,ncas,ncas) - def dump_chk(self, envs): - if not self.chkfile: - return self + def dump_chk(self, envs_or_file): + '''Serialize the MCSCF object and save it to the specified chkfile. - if getattr(self.fcisolver, 'nevpt_intermediate', None): - civec = None - elif self.chk_ci: - civec = envs['fcivec'] + Args: + envs_or_file: + If this argument is a file path, the serialized MCSCF object is + saved to the file specified by this argument. + If this attribute is a dict (created by locals()), the necessary + variables are saved to the file specified by the attribute .chkfile. + ''' + if isinstance(envs_or_file, str): + envs = None + chk_file = envs_or_file else: - civec = None + envs = envs_or_file + chk_file = self.chkfile + if not chk_file: + return self + + e_tot = mo_coeff = mo_occ = mo_energy = e_cas = civec = casdm1 = None ncore = self.ncore nocc = ncore + self.ncas - if 'mo' in envs: - mo_coeff = envs['mo'] - else: - mo_coeff = envs['mo_coeff'] - mo_occ = numpy.zeros(mo_coeff.shape[1]) - mo_occ[:ncore] = 2 - if self.natorb: - occ = self._eig(-envs['casdm1'], ncore, nocc)[0] - mo_occ[ncore:nocc] = -occ - else: - mo_occ[ncore:nocc] = envs['casdm1'].diagonal() -# Note: mo_energy in active space =/= F_{ii} (F is general Fock) - if 'mo_energy' in envs: - mo_energy = envs['mo_energy'] - else: - mo_energy = 'None' - chkfile.dump_mcscf(self, self.chkfile, 'mcscf', envs['e_tot'], + + if envs is not None: + if self.chk_ci: + civec = envs.get('fcivec', None) + + e_tot = envs['e_tot'] + e_cas = envs['e_cas'] + casdm1 = envs['casdm1'] + if 'mo' in envs: + mo_coeff = envs['mo'] + else: + mo_coeff = envs['mo_coeff'] + mo_occ = numpy.zeros(mo_coeff.shape[1]) + mo_occ[:ncore] = 2 + if self.natorb: + occ = self._eig(-casdm1, ncore, nocc)[0] + mo_occ[ncore:nocc] = -occ + else: + mo_occ[ncore:nocc] = casdm1.diagonal() + # Note: mo_energy in active space =/= F_{ii} (F is general Fock) + if 'mo_energy' in envs: + mo_energy = envs['mo_energy'] + + chkfile.dump_mcscf(self, chk_file, 'mcscf', e_tot, mo_coeff, ncore, self.ncas, mo_occ, - mo_energy, envs['e_cas'], civec, envs['casdm1'], - overwrite_mol=False) + mo_energy, e_cas, civec, casdm1, + overwrite_mol=(envs is None)) return self def update_from_chk(self, chkfile=None): diff --git a/pyscf/mcscf/umc1step.py b/pyscf/mcscf/umc1step.py index bb8b8778a9..777799028b 100644 --- a/pyscf/mcscf/umc1step.py +++ b/pyscf/mcscf/umc1step.py @@ -27,6 +27,7 @@ import numpy import pyscf.gto import pyscf.scf +from pyscf import lib from pyscf.lib import logger from pyscf.mcscf import ucasci from pyscf.mcscf.mc1step import expmat, rotate_orb_cc, max_stepsize_scheduler, as_scanner @@ -399,6 +400,9 @@ def __init__(self, mf_or_mol, ncas=0, nelecas=0, ncore=None, frozen=None): self.converged = False self._max_stepsize = None + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('chkfile', 'callback')) + def dump_flags(self, verbose=None): log = logger.new_logger(self, verbose) log.info('') @@ -732,39 +736,54 @@ def solve_approx_ci(self, h1, h2, ci0, ecore, e_cas): ci1 += xs[i] * v[i,0] return ci1, g - def dump_chk(self, envs): - if not self.chkfile: - return self + def dump_chk(self, envs_or_file): + '''Serialize the MCSCF object and save it to the specified chkfile. - if self.chk_ci: - civec = envs['fcivec'] + Args: + envs_or_file: + If this argument is a file path, the serialized MCSCF object is + saved to the file specified by this argument. + If this attribute is a dict (created by locals()), the necessary + variables are saved to the file specified by the attribute .chkfile. + ''' + if isinstance(envs_or_file, str): + envs = None + chk_file = envs_or_file else: - civec = None + envs = envs_or_file + chk_file = self.chkfile + if not chk_file: + return self + + e_tot = mo_coeff = mo_occ = mo_energy = e_cas = civec = casdm1 = None ncore = self.ncore ncas = self.ncas nocca = ncore[0] + ncas noccb = ncore[1] + ncas - if 'mo' in envs: - mo_coeff = envs['mo'] - else: - mo_coeff = envs['mo'] - mo_occ = numpy.zeros((2,envs['mo'][0].shape[1])) - mo_occ[0,:ncore[0]] = 1 - mo_occ[1,:ncore[1]] = 1 - if self.natorb: - occa, ucas = self._eig(-envs['casdm1'][0], ncore[0], nocca) - occb, ucas = self._eig(-envs['casdm1'][1], ncore[1], noccb) - mo_occ[0,ncore[0]:nocca] = -occa - mo_occ[1,ncore[1]:noccb] = -occb - else: - mo_occ[0,ncore[0]:nocca] = envs['casdm1'][0].diagonal() - mo_occ[1,ncore[1]:noccb] = envs['casdm1'][1].diagonal() - mo_energy = 'None' - chkfile.dump_mcscf(self, self.chkfile, 'mcscf', envs['e_tot'], + if envs is not None: + if self.chk_ci: + civec = envs['fcivec'] + if 'mo' in envs: + mo_coeff = envs['mo'] + else: + mo_coeff = envs['mo_coeff'] + mo_occ = numpy.zeros((2,envs['mo'][0].shape[1])) + mo_occ[0,:ncore[0]] = 1 + mo_occ[1,:ncore[1]] = 1 + if self.natorb: + occa, ucas = self._eig(-casdm1[0], ncore[0], nocca) + occb, ucas = self._eig(-casdm1[1], ncore[1], noccb) + mo_occ[0,ncore[0]:nocca] = -occa + mo_occ[1,ncore[1]:noccb] = -occb + else: + mo_occ[0,ncore[0]:nocca] = casdm1[0].diagonal() + mo_occ[1,ncore[1]:noccb] = casdm1[1].diagonal() + + chkfile.dump_mcscf(self, self.chkfile, 'mcscf', e_tot, mo_coeff, ncore, ncas, mo_occ, - mo_energy, envs['e_cas'], civec, envs['casdm1'], - overwrite_mol=False) + mo_energy, e_cas, civec, casdm1, + overwrite_mol=(envs is None)) return self def rotate_mo(self, mo, u, log=None): diff --git a/pyscf/pbc/df/df.py b/pyscf/pbc/df/df.py index 3ac16573c4..415be72bc6 100644 --- a/pyscf/pbc/df/df.py +++ b/pyscf/pbc/df/df.py @@ -176,6 +176,9 @@ def __init__(self, cell, kpts=numpy.zeros((1,3))): self._cderi = None self._rsh_df = {} # Range separated Coulomb DF objects + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True) + @property def auxbasis(self): return self._auxbasis diff --git a/pyscf/pbc/df/mdf.py b/pyscf/pbc/df/mdf.py index c24fb2f184..15fd9d8b36 100644 --- a/pyscf/pbc/df/mdf.py +++ b/pyscf/pbc/df/mdf.py @@ -94,6 +94,9 @@ def __init__(self, cell, kpts=np.zeros((1,3))): self._cderi = None self._rsh_df = {} # Range separated Coulomb DF objects + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('_cderi_to_save', '_cderi', '_rsh_df'), reset_state=True) + def build(self, j_only=None, with_j3c=True, kpts_band=None): df.GDF.build(self, j_only, with_j3c, kpts_band) cell = self.cell diff --git a/pyscf/pbc/scf/hf.py b/pyscf/pbc/scf/hf.py index a7f099c712..dda8dc897a 100644 --- a/pyscf/pbc/scf/hf.py +++ b/pyscf/pbc/scf/hf.py @@ -786,9 +786,21 @@ def init_guess_by_chkfile(self, chk=None, project=None, kpt=None): def from_chk(self, chk=None, project=None, kpt=None): return self.init_guess_by_chkfile(chk, project, kpt) - def dump_chk(self, envs): - if self.chkfile: - mol_hf.SCF.dump_chk(self, envs) + def dump_chk(self, envs_or_file): + '''Serialize the SCF object and save it to the specified chkfile. + + Args: + envs_or_file: + If this argument is a file path, the serialized SCF object is + saved to the file specified by this argument. + If this attribute is a dict (created by locals()), the necessary + variables are saved to the file specified by the attribute mf.chkfile. + ''' + mol_hf.SCF.dump_chk(self, envs_or_file) + if isinstance(envs_or_file, str): + with lib.H5FileWrap(envs_or_file, 'a') as fh5: + fh5['scf/kpt'] = self.kpt + elif self.chkfile: with lib.H5FileWrap(self.chkfile, 'a') as fh5: fh5['scf/kpt'] = self.kpt return self diff --git a/pyscf/pbc/scf/khf.py b/pyscf/pbc/scf/khf.py index 02753ef270..bdd1082c04 100644 --- a/pyscf/pbc/scf/khf.py +++ b/pyscf/pbc/scf/khf.py @@ -611,9 +611,21 @@ def init_guess_by_chkfile(self, chk=None, project=None, kpts=None): def from_chk(self, chk=None, project=None, kpts=None): return self.init_guess_by_chkfile(chk, project, kpts) - def dump_chk(self, envs): - if self.chkfile: - mol_hf.SCF.dump_chk(self, envs) + def dump_chk(self, envs_or_file): + '''Serialize the SCF object and save it to the specified chkfile. + + Args: + envs_or_file: + If this argument is a file path, the serialized SCF object is + saved to the file specified by this argument. + If this attribute is a dict (created by locals()), the necessary + variables are saved to the file specified by the attribute mf.chkfile. + ''' + mol_hf.SCF.dump_chk(self, envs_or_file) + if isinstance(envs_or_file, str): + with lib.H5FileWrap(envs_or_file, 'a') as fh5: + fh5['scf/kpts'] = self.kpts + elif self.chkfile: with lib.H5FileWrap(self.chkfile, 'a') as fh5: fh5['scf/kpts'] = self.kpts return self diff --git a/pyscf/pbc/scf/rsjk.py b/pyscf/pbc/scf/rsjk.py index 43877e18b6..ce7a696812 100644 --- a/pyscf/pbc/scf/rsjk.py +++ b/pyscf/pbc/scf/rsjk.py @@ -96,6 +96,11 @@ def __init__(self, cell, kpts=np.zeros((1,3))): self._last_vs = (0, 0) self._qindex = None + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('rs_cell', 'cell_d', 'supmol_sr', 'supmol_ft', 'supmol_d', + '_sr_without_dddd', '_last_vs', '_qindex'), + reset_state=True) + def has_long_range(self): '''Whether to add the long-range part computed with AFT/FFT integrals''' return self.omega is None or abs(self.cell.omega) < self.omega @@ -123,8 +128,7 @@ def reset(self, cell=None): self.supmol_sr = None self.supmol_ft = None self.supmol_d = None - self.exclude_dd_block = True - self.approx_vk_lr_missing_mo = False + self._sr_without_dddd = None self._last_vs = (0, 0) self._qindex = None return self diff --git a/pyscf/scf/dhf.py b/pyscf/scf/dhf.py index b7172b27fd..c087f3e7bb 100644 --- a/pyscf/scf/dhf.py +++ b/pyscf/scf/dhf.py @@ -471,8 +471,7 @@ class DHF(hf.SCF): # corrections for small component when with_ssss is set to False ssss_approx = getattr(__config__, 'scf_dhf_SCF_ssss_approx', 'Visscher') - _keys = {'conv_tol', 'with_ssss', 'with_gaunt', - 'with_breit', 'ssss_approx'} + _keys = {'conv_tol', 'with_ssss', 'with_gaunt', 'with_breit', 'ssss_approx'} def __init__(self, mol): hf.SCF.__init__(self, mol) diff --git a/pyscf/scf/hf.py b/pyscf/scf/hf.py index 247e295b2b..074ce1736b 100644 --- a/pyscf/scf/hf.py +++ b/pyscf/scf/hf.py @@ -194,7 +194,7 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, elif abs(e_tot-last_hf_e) < conv_tol and norm_gorb < conv_tol_grad: scf_conv = True - if dump_chk: + if dump_chk and mf.chkfile: mf.dump_chk(locals()) if callable(callback): @@ -228,7 +228,7 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, scf_conv = True logger.info(mf, 'Extra cycle E= %.15g delta_E= %4.3g |g|= %4.3g |ddm|= %4.3g', e_tot, e_tot-last_hf_e, norm_gorb, norm_ddm) - if dump_chk: + if dump_chk and mf.chkfile: mf.dump_chk(locals()) logger.timer(mf, 'scf_cycle', *cput0) @@ -1646,6 +1646,9 @@ def __init__(self, mol): self._opt = {None: None} self._eri = None # Note: self._eri requires large amount of memory + __getstate__, __setstate__ = lib.generate_pickle_methods( + excludes=('chkfile', '_chkfile', '_opt', '_eri', 'callback')) + def __getattr__(self, key): '''Accessing methods post-HF methods or mean-field properties''' # Import all available modules, then retry accessing the attribute @@ -1740,8 +1743,21 @@ def get_grad(self, mo_coeff, mo_occ, fock=None): fock = self.get_hcore(self.mol) + self.get_veff(self.mol, dm1) return get_grad(mo_coeff, mo_occ, fock) - def dump_chk(self, envs): - if self.chkfile: + def dump_chk(self, envs_or_file): + '''Serialize the SCF object and save it to the specified chkfile. + + Args: + envs_or_file: + If this argument is a file path, the serialized SCF object is + saved to the file specified by this argument. + If this attribute is a dict (created by locals()), the necessary + variables are saved to the file specified by the attribute mf.chkfile. + ''' + if isinstance(envs_or_file, str): + chkfile.dump_scf(self.mol, envs_or_file, self.e_tot, self.mo_energy, + self.mo_coeff, self.mo_occ) + elif self.chkfile: + envs = envs_or_file chkfile.dump_scf(self.mol, self.chkfile, envs['e_tot'], envs['mo_energy'], envs['mo_coeff'], envs['mo_occ'],