From 1cd7be46aa7187ea768e39958f5c3e19b7715d54 Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Sun, 26 Jan 2025 16:24:44 -0800 Subject: [PATCH] Complex dms for DFHF (fix #2670) --- pyscf/df/df_jk.py | 31 +++++++++++++++++++++++++++++-- pyscf/df/test/test_df_jk.py | 12 ++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/pyscf/df/df_jk.py b/pyscf/df/df_jk.py index 9e5670537f..2cd630b3c7 100644 --- a/pyscf/df/df_jk.py +++ b/pyscf/df/df_jk.py @@ -268,6 +268,27 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): # uses numpy.matmul vj += dmtril.dot(eri1.T).dot(eri1) + elif dms.dtype != numpy.float64: + if with_j: + vj = numpy.zeros_like(dms) + max_memory = dfobj.max_memory - lib.current_memory()[0] + blksize = max(4, int(min(dfobj.blockdim, max_memory*.22e6/8/nao**2))) + buf = numpy.empty((blksize,nao,nao)) + buf1 = numpy.empty((nao,blksize,nao)) + for eri1 in dfobj.loop(blksize): + naux, nao_pair = eri1.shape + eri1 = lib.unpack_tril(eri1, out=buf) + if with_j: + tmp = numpy.einsum('pij,nji->pn', eri1, dms) + vj += numpy.einsum('pn,pij->nij', tmp, eri1) + buf2 = numpy.ndarray((nao,naux,nao), buffer=buf1) + for k in range(nset): + buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].real) + vk[k].real += lib.einsum('ipk,pkj->ij', buf2, eri1) + buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].imag) + vk[k].imag += lib.einsum('ipk,pkj->ij', buf2, eri1) + t1 = log.timer_debug1('jk', *t1) + elif getattr(dm, 'mo_coeff', None) is not None: #TODO: test whether dm.mo_coeff matching dm mo_coeff = numpy.asarray(dm.mo_coeff, order='F') @@ -322,6 +343,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): buf = numpy.empty((2,blksize,nao,nao)) for eri1 in dfobj.loop(blksize): naux, nao_pair = eri1.shape + assert (nao_pair == nao*(nao+1)//2) if with_j: # uses numpy.matmul vj += dmtril.dot(eri1.T).dot(eri1) @@ -338,8 +360,12 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): vk[k] += lib.dot(buf1.reshape(-1,nao).T, buf2.reshape(-1,nao)) t1 = log.timer_debug1('jk', *t1) - if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape) - if with_k: vk = vk.reshape(dm_shape) + if with_j: + if dms.dtype == numpy.float64: + vj = lib.unpack_tril(vj, 1) + vj = vj.reshape(dm_shape) + if with_k: + vk = vk.reshape(dm_shape) logger.timer(dfobj, 'df vj and vk', *t0) return vj, vk @@ -348,6 +374,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13): from pyscf.scf import jk from pyscf.df import addons t0 = t1 = (logger.process_clock(), logger.perf_counter()) + assert dm.dtype == numpy.float64 mol = dfobj.mol if dfobj._vjopt is None: diff --git a/pyscf/df/test/test_df_jk.py b/pyscf/df/test/test_df_jk.py index e418a70204..d426430ae3 100644 --- a/pyscf/df/test/test_df_jk.py +++ b/pyscf/df/test/test_df_jk.py @@ -194,6 +194,18 @@ def test_get_j(self): self.assertAlmostEqual(abs(vj0-vj1).max(), 0, 12) self.assertAlmostEqual(lib.fp(vj0), -194.15910890730052, 9) + def test_df_jk_complex_dm(self): + mol = gto.M(atom='H 0 0 0; H 0 0 1') + mf = mol.RHF().run() + dm = mf.make_rdm1() + 0j + dm[0,:] += .1j + dm[:,0] -= .1j + mf.kernel(dm) + self.assertTrue(mf.mo_coeff.dtype == numpy.complex128) + dfmf = mf.density_fit() + self.assertAlmostEqual(dfmf.energy_tot(), -1.0661355663696201, 9) + self.assertAlmostEqual(dfmf.energy_tot(), mf.e_tot, 3) + if __name__ == "__main__": print("Full Tests for df")