From 1e8540f7f32c8b765d44e0e2528745200fcb4074 Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Sun, 26 Jan 2025 20:49:36 -0800 Subject: [PATCH] Fix dimension issue --- pyscf/df/df_jk.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/pyscf/df/df_jk.py b/pyscf/df/df_jk.py index 2cd630b3c7..74de84ceef 100644 --- a/pyscf/df/df_jk.py +++ b/pyscf/df/df_jk.py @@ -258,17 +258,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): vj = 0 vk = numpy.zeros_like(dms) - if with_j: - idx = numpy.arange(nao) - dmtril = lib.pack_tril(dms + dms.conj().transpose(0,2,1)) - dmtril[:,idx*(idx+1)//2+idx] *= .5 - - if not with_k: - for eri1 in dfobj.loop(): - # uses numpy.matmul - vj += dmtril.dot(eri1.T).dot(eri1) - - elif dms.dtype != numpy.float64: + if dms.dtype != numpy.float64: if with_j: vj = numpy.zeros_like(dms) max_memory = dfobj.max_memory - lib.current_memory()[0] @@ -279,8 +269,10 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): naux, nao_pair = eri1.shape eri1 = lib.unpack_tril(eri1, out=buf) if with_j: - tmp = numpy.einsum('pij,nji->pn', eri1, dms) - vj += numpy.einsum('pn,pij->nij', tmp, eri1) + tmp = numpy.einsum('pij,nji->pn', eri1, dms.real) + vj.real += numpy.einsum('pn,pij->nij', tmp, eri1) + tmp = numpy.einsum('pij,nji->pn', eri1, dms.imag) + vj.imag += numpy.einsum('pn,pij->nij', tmp, eri1) buf2 = numpy.ndarray((nao,naux,nao), buffer=buf1) for k in range(nset): buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].real) @@ -288,6 +280,20 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].imag) vk[k].imag += lib.einsum('ipk,pkj->ij', buf2, eri1) t1 = log.timer_debug1('jk', *t1) + if with_j: vj = vj.reshape(dm_shape) + if with_k: vk = vk.reshape(dm_shape) + logger.timer(dfobj, 'df vj and vk', *t0) + return vj, vk + + if with_j: + idx = numpy.arange(nao) + dmtril = lib.pack_tril(dms + dms.conj().transpose(0,2,1)) + dmtril[:,idx*(idx+1)//2+idx] *= .5 + + if not with_k: + for eri1 in dfobj.loop(): + # uses numpy.matmul + vj += dmtril.dot(eri1.T).dot(eri1) elif getattr(dm, 'mo_coeff', None) is not None: #TODO: test whether dm.mo_coeff matching dm @@ -360,12 +366,8 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13): vk[k] += lib.dot(buf1.reshape(-1,nao).T, buf2.reshape(-1,nao)) t1 = log.timer_debug1('jk', *t1) - if with_j: - if dms.dtype == numpy.float64: - vj = lib.unpack_tril(vj, 1) - vj = vj.reshape(dm_shape) - if with_k: - vk = vk.reshape(dm_shape) + if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape) + if with_k: vk = vk.reshape(dm_shape) logger.timer(dfobj, 'df vj and vk', *t0) return vj, vk @@ -374,7 +376,6 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13): from pyscf.scf import jk from pyscf.df import addons t0 = t1 = (logger.process_clock(), logger.perf_counter()) - assert dm.dtype == numpy.float64 mol = dfobj.mol if dfobj._vjopt is None: @@ -415,6 +416,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13): opt = dfobj._vjopt fakemol = opt.fakemol dm = numpy.asarray(dm, order='C') + assert dm.dtype == numpy.float64 dm_shape = dm.shape nao = dm_shape[-1] dm = dm.reshape(-1,nao,nao)