Skip to content

Commit

Permalink
Complex dms for DFHF (fix pyscf#2670)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jan 27, 2025
1 parent ccedc56 commit 1cd7be4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
31 changes: 29 additions & 2 deletions pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,27 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
# uses numpy.matmul
vj += dmtril.dot(eri1.T).dot(eri1)

elif dms.dtype != numpy.float64:
if with_j:
vj = numpy.zeros_like(dms)
max_memory = dfobj.max_memory - lib.current_memory()[0]
blksize = max(4, int(min(dfobj.blockdim, max_memory*.22e6/8/nao**2)))
buf = numpy.empty((blksize,nao,nao))
buf1 = numpy.empty((nao,blksize,nao))
for eri1 in dfobj.loop(blksize):
naux, nao_pair = eri1.shape
eri1 = lib.unpack_tril(eri1, out=buf)
if with_j:
tmp = numpy.einsum('pij,nji->pn', eri1, dms)
vj += numpy.einsum('pn,pij->nij', tmp, eri1)
buf2 = numpy.ndarray((nao,naux,nao), buffer=buf1)
for k in range(nset):
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].real)
vk[k].real += lib.einsum('ipk,pkj->ij', buf2, eri1)
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].imag)
vk[k].imag += lib.einsum('ipk,pkj->ij', buf2, eri1)
t1 = log.timer_debug1('jk', *t1)

elif getattr(dm, 'mo_coeff', None) is not None:
#TODO: test whether dm.mo_coeff matching dm
mo_coeff = numpy.asarray(dm.mo_coeff, order='F')
Expand Down Expand Up @@ -322,6 +343,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
buf = numpy.empty((2,blksize,nao,nao))
for eri1 in dfobj.loop(blksize):
naux, nao_pair = eri1.shape
assert (nao_pair == nao*(nao+1)//2)
if with_j:
# uses numpy.matmul
vj += dmtril.dot(eri1.T).dot(eri1)
Expand All @@ -338,8 +360,12 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
vk[k] += lib.dot(buf1.reshape(-1,nao).T, buf2.reshape(-1,nao))
t1 = log.timer_debug1('jk', *t1)

if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape)
if with_k: vk = vk.reshape(dm_shape)
if with_j:
if dms.dtype == numpy.float64:
vj = lib.unpack_tril(vj, 1)
vj = vj.reshape(dm_shape)
if with_k:
vk = vk.reshape(dm_shape)
logger.timer(dfobj, 'df vj and vk', *t0)
return vj, vk

Expand All @@ -348,6 +374,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13):
from pyscf.scf import jk
from pyscf.df import addons
t0 = t1 = (logger.process_clock(), logger.perf_counter())
assert dm.dtype == numpy.float64

mol = dfobj.mol
if dfobj._vjopt is None:
Expand Down
12 changes: 12 additions & 0 deletions pyscf/df/test/test_df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ def test_get_j(self):
self.assertAlmostEqual(abs(vj0-vj1).max(), 0, 12)
self.assertAlmostEqual(lib.fp(vj0), -194.15910890730052, 9)

def test_df_jk_complex_dm(self):
mol = gto.M(atom='H 0 0 0; H 0 0 1')
mf = mol.RHF().run()
dm = mf.make_rdm1() + 0j
dm[0,:] += .1j
dm[:,0] -= .1j
mf.kernel(dm)
self.assertTrue(mf.mo_coeff.dtype == numpy.complex128)
dfmf = mf.density_fit()
self.assertAlmostEqual(dfmf.energy_tot(), -1.0661355663696201, 9)
self.assertAlmostEqual(dfmf.energy_tot(), mf.e_tot, 3)


if __name__ == "__main__":
print("Full Tests for df")
Expand Down

0 comments on commit 1cd7be4

Please sign in to comment.