Skip to content

Commit

Permalink
Use self.ci if set for mcscf (pyscf#2342)
Browse files Browse the repository at this point in the history
* use casscf.ci if set

* also fix the free kernel

* update test

* remove global mols

* update test to align with casci

* Improve error types in FCI modules

---------

Co-authored-by: Qiming Sun <[email protected]>
  • Loading branch information
matthew-hennefarth and sunqm authored Aug 5, 2024
1 parent a9c3e40 commit ab210cd
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 15 deletions.
6 changes: 4 additions & 2 deletions pyscf/fci/direct_spin0_symm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def get_init_guess(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
ci0.append(x.ravel().view(direct_spin1.FCIvector))

if len(ci0) == 0:
raise RuntimeError(f'Initial guess for symmetry {wfnsym} not found')
raise lib.exceptions.WfnSymmetryError(
f'Initial guess for symmetry {wfnsym} not found')
return ci0

def get_init_guess_cyl_sym(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
Expand Down Expand Up @@ -146,7 +147,8 @@ def get_init_guess_cyl_sym(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
break

if len(ci0) == 0:
raise RuntimeError(f'Initial guess for symmetry {wfnsym} not found')
raise lib.exceptions.WfnSymmetryError(
f'Initial guess for symmetry {wfnsym} not found')
return ci0


Expand Down
20 changes: 12 additions & 8 deletions pyscf/fci/direct_spin1_symm.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def _get_init_guess(airreps, birreps, nroots, hdiag, nelec, orbsym, wfnsym=0):
ci0.append(x.ravel().view(direct_spin1.FCIvector))

if len(ci0) == 0:
raise RuntimeError(f'Initial guess for symmetry {wfnsym} not found')
raise lib.exceptions.WfnSymmetryError(
f'Initial guess for symmetry {wfnsym} not found')
return ci0

def get_init_guess(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
Expand All @@ -263,7 +264,8 @@ def get_init_guess(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
ci0.append(x.ravel().view(direct_spin1.FCIvector))

if len(ci0) == 0:
raise RuntimeError(f'Initial guess for symmetry {wfnsym} not found')
raise lib.exceptions.WfnSymmetryError(
f'Initial guess for symmetry {wfnsym} not found')
return ci0

def _validate_degen_mapping(mapping, norb):
Expand Down Expand Up @@ -355,7 +357,8 @@ def get_init_guess_cyl_sym(norb, nelec, nroots, hdiag, orbsym, wfnsym=0):
break

if len(ci0) == 0:
raise RuntimeError(f'Initial guess for symmetry {wfnsym} not found')
raise lib.exceptions.WfnSymmetryError(
f'Initial guess for symmetry {wfnsym} not found')
return ci0

def _cyl_sym_csf2civec(strs, addr, orbsym, degen_mapping):
Expand Down Expand Up @@ -598,8 +601,9 @@ def guess_wfnsym(solver, norb, nelec, fcivec=None, orbsym=None, wfnsym=None, **k
fcivec = fcivec[0]
wfnsym1 = _guess_wfnsym_cyl_sym(fcivec, strsa, strsb, orbsym)
if wfnsym1 != _id_wfnsym(solver, norb, nelec, orbsym, wfnsym):
raise RuntimeError(f'Input wfnsym {wfnsym} is not consistent with '
f'fcivec symmetry {wfnsym1}')
raise lib.exceptions.WfnSymmetryError(
f'Input wfnsym {wfnsym} is not consistent with '
f'fcivec symmetry {wfnsym1}')
wfnsym = wfnsym1
else:
na, nb = strsa.size, strsb.size
Expand All @@ -617,8 +621,8 @@ def guess_wfnsym(solver, norb, nelec, fcivec=None, orbsym=None, wfnsym=None, **k
if isinstance(fcivec, np.ndarray) and fcivec.ndim <= 2:
fcivec = [fcivec]
if all(abs(c.reshape(na, nb)[mask]).max() < 1e-5 for c in fcivec):
raise RuntimeError('Input wfnsym {wfnsym} is not consistent with '
'fcivec coefficients')
raise lib.exceptions.WfnSymmetryError(
'Input wfnsym {wfnsym} is not consistent with fcivec coefficients')
return wfnsym

def sym_allowed_indices(nelec, orbsym, wfnsym):
Expand Down Expand Up @@ -787,7 +791,7 @@ def kernel(self, h1e, eri, norb, nelec, ci0=None,
logger.debug(self, 'Num symmetry allowed elements %d',
sum([x.size for x in self.sym_allowed_idx]))
if s_idx.size == 0:
raise RuntimeError(
raise lib.exceptions.WfnSymmetryError(
f'Symmetry allowed determinants not found for wfnsym {wfnsym}')

if wfnsym_ir > 7:
Expand Down
3 changes: 3 additions & 0 deletions pyscf/lib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ class BasisNotFoundError(RuntimeError):

class PointGroupSymmetryError(RuntimeError):
pass

class WfnSymmetryError(RuntimeError):
pass
2 changes: 2 additions & 0 deletions pyscf/mcscf/casci.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ def kernel(casci, mo_coeff=None, ci0=None, verbose=logger.NOTE, envs=None):
and "envs" pop in kernel function
'''
if mo_coeff is None: mo_coeff = casci.mo_coeff
if ci0 is None: ci0 = casci.ci

log = logger.new_logger(casci, verbose)
t0 = (logger.process_clock(), logger.perf_counter())
log.debug('Start CASCI')
Expand Down
6 changes: 6 additions & 0 deletions pyscf/mcscf/mc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None,
if callback is None:
callback = casscf.callback

if ci0 is None:
ci0 = casscf.ci

mo = mo_coeff
nmo = mo_coeff.shape[1]
ncore = casscf.ncore
Expand Down Expand Up @@ -853,6 +856,9 @@ def kernel(self, mo_coeff=None, ci0=None, callback=None, _kern=kernel):
self.mo_coeff = mo_coeff
if callback is None: callback = self.callback

if ci0 is None:
ci0 = self.ci

self.check_sanity()
self.dump_flags()

Expand Down
10 changes: 5 additions & 5 deletions pyscf/mcscf/test/test_mc1step.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setUpModule():
global mol, molsym, m, msym, mc0
b = 1.4
mol = gto.M(
verbose = 5,
verbose = 0,
output = '/dev/null',
atom = [
['N',( 0.000000, 0.000000, -b/2)],
Expand All @@ -42,7 +42,7 @@ def setUpModule():
mc0 = mcscf.CASSCF(m, 4, 4).run()

molsym = gto.M(
verbose = 5,
verbose = 0,
output = '/dev/null',
atom = [
['N',( 0.000000, 0.000000, -b/2)],
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_mc1step_symm_with_x2c_scanner(self):
self.assertAlmostEqual(mc1.e_tot, -109.02535605303684, 7)

def test_0core_0virtual(self):
mol = gto.M(atom='He', basis='321g')
mol = gto.M(atom='He', basis='321g', verbose=0)
mf = scf.RHF(mol).run()
mc1 = mcscf.CASSCF(mf, 2, 2).run()
self.assertAlmostEqual(mc1.e_tot, -2.850576699649737, 9)
Expand Down Expand Up @@ -316,14 +316,14 @@ def test_state_average_mix(self):
mc.analyze()
mo_coeff, civec, mo_occ = mc.cas_natorb(sort=True)

mc.kernel(mo_coeff=mo_coeff)
mc.kernel(mo_coeff=mo_coeff, ci0=civec)
self.assertAlmostEqual(mc.e_states[0], -108.7506795311190, 5)
self.assertAlmostEqual(mc.e_states[1], -108.8582272809495, 5)
self.assertAlmostEqual(abs((civec[0]*mc.ci[0]).sum()), 1, 7)
self.assertAlmostEqual(abs((civec[1]*mc.ci[1]).sum()), 1, 7)

def test_small_system(self):
mol = gto.M(atom='H 0 0 0; H 0 0 .74', symmetry=True, basis='6-31g')
mol = gto.M(atom='H 0 0 0; H 0 0 .74', symmetry=True, basis='6-31g', verbose=0)
mf = scf.RHF(mol).run()
mc = mcscf.CASSCF(mf, 2, 2)
mc.max_cycle = 5
Expand Down
4 changes: 4 additions & 0 deletions pyscf/mcscf/test/test_n2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy
from pyscf import gto
from pyscf import scf
from pyscf import lib
from pyscf import mcscf

def setUpModule():
Expand Down Expand Up @@ -166,6 +167,9 @@ def test_wfnsym(self):
self.assertAlmostEqual(emc, -108.74508322877787, 7)

mc.wfnsym = 'A2u'
with self.assertRaises(lib.exceptions.WfnSymmetryError):
mc.mc1step()
mc.ci = None
emc = mc.mc1step()[0]
self.assertAlmostEqual(emc, -108.69019443475308, 7)

Expand Down

0 comments on commit ab210cd

Please sign in to comment.