From 0ada19a297366eff0b5b0b4467e3ef290a7f9cf6 Mon Sep 17 00:00:00 2001 From: Xubo Wang <34044625+xubwa@users.noreply.github.com> Date: Fri, 18 Oct 2024 20:51:00 -0400 Subject: [PATCH] integral screening for gaunt and breit term (#2437) * integral screening for gaunt and breit term * Revert to using the g_lssl integral itself to screen gaunt. Since the q_cond for g_lssl and g_lsls are not symmetric, an asymmetric q_cond is added. * The mf object is reinitialized at line 241; otherwise the previously initialized VHFOpt for gaunt would be used for breit as well and would thus cause a wrong result. * trim whitespace --- pyscf/lib/vhf/rkb_screen.c | 139 +++++++++++++++++++++++++++++++++++++ pyscf/scf/dhf.py | 102 +++++++++++++++++++-------- pyscf/scf/test/test_dhf.py | 1 + 3 files changed, 214 insertions(+), 28 deletions(-) diff --git a/pyscf/lib/vhf/rkb_screen.c b/pyscf/lib/vhf/rkb_screen.c index 4ff0063140..223d1a64bb 100644 --- a/pyscf/lib/vhf/rkb_screen.c +++ b/pyscf/lib/vhf/rkb_screen.c @@ -32,6 +32,9 @@ #define SL 2 #define LS 3 +#define GAUNT_LL 0 +#define GAUNT_SS 1 +#define GAUNT_LS 2 // in gaunt_lssl screening, put order to ll, ss, ls int int2e_spinor(); int int2e_spsp1spsp2_spinor(); @@ -148,6 +151,61 @@ int CVHFrkbssll_vkscreen(int *shls, CVHFOpt *opt, } +int CVHFrkb_gaunt_lsls_prescreen(int *shls, CVHFOpt *opt, + int *atm, int *bas, double *env) +{ + if (opt == NULL) { + return 1; // no screen + } + 
int i = shls[0]; + int j = shls[1]; + int k = shls[2]; + int l = shls[3]; + int n = opt->nbas; + assert(opt->q_cond); + assert(opt->dm_cond); + assert(i < n); + assert(j < n); + assert(k < n); + assert(l < n); + double *dmll = opt->dm_cond + n*n*GAUNT_LL; + double *dmss = opt->dm_cond + n*n*GAUNT_SS; + double *dmls = opt->dm_cond + n*n*GAUNT_LS; + double qijkl = opt->q_cond[i*n+j] * opt->q_cond[k*n+l]; + double dmin = opt->direct_scf_cutoff / qijkl; + return qijkl > opt->direct_scf_cutoff + &&((dmll[j*n+k] > dmin) // dmss_ji + || (dmss[l*n+i] > dmin) // dmll_lk + || (dmls[l*n+k] > dmin)); // dmls_lk +} + + void CVHFrkb_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) @@ -194,6 +252,67 @@ void CVHFrkb_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond, } } + +void CVHFrkb_asym_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + int shls_slice[] = {0, nbas}; + const int cache_size = GTOmax_cache_size(intor, shls_slice, 1, + atm, natm, bas, nbas, env); +#pragma omp parallel +{ + double qtmp, tmp; + int i, j, ij, di, dj, ish, jsh; + int shls[4]; + double *cache = malloc(sizeof(double) * cache_size); + di = 0; + for (ish = 0; ish < nbas; ish++) { + dj = ao_loc[ish+1] - ao_loc[ish]; + di = MAX(di, dj); + } + double complex *buf = malloc(sizeof(double complex) * di*di*di*di); +#pragma omp for schedule(dynamic, 4) + for (ij = 0; ij < nbas*(nbas+1)/2; ij++) { + ish = (int)(sqrt(2*ij+.25) - .5 + 1e-7); + jsh = ij - ish*(ish+1)/2; + di = ao_loc[ish+1] - ao_loc[ish]; + dj = ao_loc[jsh+1] - ao_loc[jsh]; + shls[0] = ish; + shls[1] = jsh; + shls[2] = ish; + shls[3] = jsh; + qtmp = 1e-100; + if (0 != (*intor)(buf, NULL, shls, atm, natm, bas, nbas, env, cintopt, cache)) { + for (i = 0; i < di; i++) { + for (j = 0; j < dj; j++) { + tmp = cabs(buf[i+di*j+di*dj*i+di*dj*di*j]); + qtmp = MAX(qtmp, tmp); + } } + qtmp = sqrt(qtmp); + } 
+ qcond[ish*nbas+jsh] = qtmp; + shls[0] = jsh; + shls[1] = ish; + shls[2] = jsh; + shls[3] = ish; + qtmp = 1e-100; + if (0 != (*intor)(buf, NULL, shls, atm, natm, bas, nbas, env, cintopt, cache)) { + for (i = 0; i < di; i++) { + for (j = 0; j < dj; j++) { + tmp = cabs(buf[j+dj*i+dj*di*j+dj*di*dj*i]); + qtmp = MAX(qtmp, tmp); + } } + qtmp = sqrt(qtmp); + } + qcond[jsh*nbas+ish] = qtmp; + } + free(buf); + free(cache); +} +} + + void CVHFrkbllll_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) @@ -331,6 +450,26 @@ void CVHFrkbssll_dm_cond(double *dm_cond, double complex *dm, int nset, int *ao_ } } +// the current order of dmscond (dmls, dmll, dmss) is consistent to the +// second contraction in function _call_veff_gaunt_breit in dhf.py +void CVHFrkb_gaunt_lssl_dm_cond(double *dm_cond, double complex *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) +{ + nset = nset / 3; + int n2c = CINTtot_cgto_spinor(bas, nbas); + size_t nbas2 = nbas * nbas; + double *dmcondll = dm_cond + (1+nset)*nbas2*GAUNT_LL; + double *dmcondss = dm_cond + (1+nset)*nbas2*GAUNT_SS; + double *dmcondls = dm_cond + (1+nset)*nbas2*GAUNT_LS; + double complex *dmll = dm + n2c*n2c*GAUNT_LL*nset; + double complex *dmss = dm + n2c*n2c*GAUNT_SS*nset; + double complex *dmls = dm + n2c*n2c*GAUNT_LS*nset; + + CVHFrkb_dm_cond(dmcondll, dmll, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondss, dmss, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondls, dmls, nset, ao_loc, atm, natm, bas, nbas, env); +} + // the current order of dmscond (dmll, dmss, dmsl) is consistent to the // function _call_veff_ssll in dhf.py void CVHFrkbssll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, diff --git a/pyscf/scf/dhf.py b/pyscf/scf/dhf.py index e8a4bb67b2..50b3ae1a68 100644 --- a/pyscf/scf/dhf.py +++ b/pyscf/scf/dhf.py @@ -61,7 +61,8 @@ def kernel(mf, conv_tol=1e-9, 
conv_tol_grad=None, else: dm = dm0 - mf._coulomb_level = 'LLLL' + if mf.init_guess != 'chkfile': + mf._coulomb_level = 'LLLL' cycles = 0 if dm0 is None and mf._coulomb_level.upper() == 'LLLL': scf_conv, e_tot, mo_energy, mo_coeff, mo_occ \ @@ -145,6 +146,7 @@ def get_jk_coulomb(mol, dm, hermi=1, coulomb_allow='SSSS', opt_llll=None, opt_ssll=None, opt_ssss=None, omega=None, verbose=None): log = logger.new_logger(mol, verbose) + t0 = (logger.process_clock(), logger.perf_counter()) if hermi == 0 and DEBUG: # J matrix is symmetrized in this function which is only true for @@ -155,29 +157,35 @@ def get_jk_coulomb(mol, dm, hermi=1, coulomb_allow='SSSS', if coulomb_allow.upper() == 'LLLL': log.debug('Coulomb integral: (LL|LL)') j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll) + log.timer_debug1('LLLL', *t0) n2c = j1.shape[1] vj = numpy.zeros_like(dm) vk = numpy.zeros_like(dm) - vj[...,:n2c,:n2c] = j1 - vk[...,:n2c,:n2c] = k1 + vj[..., :n2c, :n2c] = j1 + vk[..., :n2c, :n2c] = k1 elif coulomb_allow.upper() == 'SSLL' \ or coulomb_allow.upper() == 'LLSS': log.debug('Coulomb integral: (LL|LL) + (SS|LL)') vj, vk = _call_veff_ssll(mol, dm, hermi, opt_ssll) + t0 = log.timer_debug1('SSLL', *t0) j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll) + log.timer_debug1('LLLL', *t0) n2c = j1.shape[1] - vj[...,:n2c,:n2c] += j1 - vk[...,:n2c,:n2c] += k1 - else: # coulomb_allow == 'SSSS' + vj[..., :n2c, :n2c] += j1 + vk[..., :n2c, :n2c] += k1 + else: # coulomb_allow == 'SSSS' log.debug('Coulomb integral: (LL|LL) + (SS|LL) + (SS|SS)') vj, vk = _call_veff_ssll(mol, dm, hermi, opt_ssll) + t0 = log.timer_debug1('SSLL', *t0) j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll) + t0 = log.timer_debug1('LLLL', *t0) n2c = j1.shape[1] - vj[...,:n2c,:n2c] += j1 - vk[...,:n2c,:n2c] += k1 + vj[..., :n2c, :n2c] += j1 + vk[..., :n2c, :n2c] += k1 j1, k1 = _call_veff_ssss(mol, dm, hermi, opt_ssss) - vj[...,n2c:,n2c:] += j1 - vk[...,n2c:,n2c:] += k1 + log.timer_debug1('SSSS', *t0) + vj[..., n2c:, n2c:] += 
j1 + vk[..., n2c:, n2c:] += k1 return vj, vk get_jk = get_jk_coulomb @@ -477,6 +485,7 @@ class DHF(hf.SCF): with_breit = getattr(__config__, 'scf_dhf_SCF_with_breit', False) # corrections for small component when with_ssss is set to False ssss_approx = getattr(__config__, 'scf_dhf_SCF_ssss_approx', 'Visscher') + screening = True _keys = {'conv_tol', 'with_ssss', 'with_gaunt', 'with_breit', 'ssss_approx'} @@ -605,11 +614,34 @@ def set_vkscreen(opt, name): direct_scf_tol=self.direct_scf_tol) opt_ssll.q_cond = numpy.array([opt_llll.q_cond, opt_ssss.q_cond]) set_vkscreen(opt_ssll, 'CVHFrkbssll_vkscreen') + logger.timer(self, 'init_direct_scf_coulomb', *cpu0) + + opt_gaunt_lsls = None + opt_gaunt_lssl = None #TODO: prescreen for gaunt - opt_gaunt = None - logger.timer(self, 'init_direct_scf', *cpu0) - return opt_llll, opt_ssll, opt_ssss, opt_gaunt + if self.with_gaunt: + if self.with_breit: + # integral function int2e_breit_ssp1ssp2_spinor evaluates + # -1/2[alpha1*alpha2/r12 + (alpha1*r12)(alpha2*r12)/r12^3] + intor_prefix = 'int2e_breit_' + else: + # integral function int2e_ssp1ssp2_spinor evaluates only + # alpha1*alpha2/r12. Minus sign was not included. 
+ intor_prefix = 'int2e_' + opt_gaunt_lsls = _VHFOpt(mol, intor_prefix + 'ssp1ssp2_spinor', + 'CVHFrkb_gaunt_lsls_prescreen', 'CVHFrkb_asym_q_cond', + 'CVHFrkb_dm_cond', + direct_scf_tol=self.direct_scf_tol/c1**2) + + opt_gaunt_lssl = _VHFOpt(mol, intor_prefix + 'ssp1sps2_spinor', + 'CVHFrkb_gaunt_lssl_prescreen', 'CVHFrkb_asym_q_cond', + 'CVHFrkb_dm_cond', + direct_scf_tol=self.direct_scf_tol/c1**2) + + logger.timer(self, 'init_direct_scf_gaunt_breit', *cpu0) + #return None, None, None, None, None + return opt_llll, opt_ssll, opt_ssss, opt_gaunt_lsls, opt_gaunt_lssl def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, omega=None): @@ -622,13 +654,16 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, self._opt[omega] = self.init_direct_scf(mol) vhfopt = self._opt.get(omega) if vhfopt is None: - opt_llll = opt_ssll = opt_ssss = opt_gaunt = None + opt_llll = opt_ssll = opt_ssss = opt_gaunt_lsls = opt_gaunt_lssl = None else: - opt_llll, opt_ssll, opt_ssss, opt_gaunt = vhfopt + opt_llll, opt_ssll, opt_ssss, opt_gaunt_lsls, opt_gaunt_lssl = vhfopt + if self.screening is False: + opt_llll = opt_ssll = opt_ssss = opt_gaunt_lsls = opt_gaunt_lssl = None + opt_gaunt = (opt_gaunt_lsls, opt_gaunt_lssl) vj, vk = get_jk_coulomb(mol, dm, hermi, self._coulomb_level, opt_llll, opt_ssll, opt_ssss, omega, log) - + t1 = log.timer_debug1('Coulomb', *t0) if self.with_breit: assert omega is None if ('SSSS' in self._coulomb_level.upper() or @@ -644,6 +679,7 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, vj1, vk1 = _call_veff_gaunt_breit(mol, dm, hermi, opt_gaunt, False) vj += vj1 vk += vk1 + log.timer_debug1('Gaunt and Breit term', *t1) log.timer('vj and vk', *t0) return vj, vk @@ -927,6 +963,14 @@ def _call_veff_ssss(mol, dm, hermi=1, mf_opt=None): return _jk_triu_(mol, vj, vk, hermi) def _call_veff_gaunt_breit(mol, dm, hermi=1, mf_opt=None, with_breit=False): + if mf_opt is not None: + opt_gaunt_lsls, opt_gaunt_lssl = 
mf_opt + else: + opt_gaunt_lsls = opt_gaunt_lssl = None + + log = logger.new_logger(mol) + t0 = (logger.process_clock(), logger.perf_counter()) + if with_breit: # integral function int2e_breit_ssp1ssp2_spinor evaluates # -1/2[alpha1*alpha2/r12 + (alpha1*r12)(alpha2*r12)/r12^3] @@ -946,7 +990,7 @@ def _call_veff_gaunt_breit(mol, dm, hermi=1, mf_opt=None, with_breit=False): dmsl = dm[n2c:,:n2c].copy() dmll = dm[:n2c,:n2c].copy() dmss = dm[n2c:,n2c:].copy() - dms = [dmsl, dmsl, dmls, dmll, dmss] + dms = [dmsl, dmls, dmll, dmss] else: n_dm = len(dm) n2c = dm[0].shape[0] // 2 @@ -954,22 +998,24 @@ def _call_veff_gaunt_breit(mol, dm, hermi=1, mf_opt=None, with_breit=False): dmls = [dmi[:n2c,n2c:].copy() for dmi in dm] dmsl = [dmi[n2c:,:n2c].copy() for dmi in dm] dmss = [dmi[n2c:,n2c:].copy() for dmi in dm] - dms = dmsl + dmsl + dmls + dmll + dmss + dms = dmsl + dmls + dmll + dmss vj = numpy.zeros((n_dm,n2c*2,n2c*2), dtype=numpy.complex128) vk = numpy.zeros((n_dm,n2c*2,n2c*2), dtype=numpy.complex128) - jks = ('lk->s1ij',) * n_dm \ - + ('jk->s1il',) * n_dm - vx = _vhf.rdirect_bindm(intor_prefix+'ssp1ssp2_spinor', 's1', jks, dms[:n_dm*2], 1, - mol._atm, mol._bas, mol._env, mf_opt) - vj[:,:n2c,n2c:] = vx[:n_dm,:,:] - vk[:,:n2c,n2c:] = vx[n_dm:,:,:] + jks = ('lk->s1ij', 'jk->s1il') + vj_ls, vk_ls = _vhf.rdirect_mapdm(intor_prefix+'ssp1ssp2_spinor', 's1', jks, dms[:n_dm], 1, + mol._atm, mol._bas, mol._env, opt_gaunt_lsls) + vj[:,:n2c,n2c:] = vj_ls + vk[:,:n2c,n2c:] = vk_ls + t0 = log.timer_debug1('LSLS contribution', *t0) jks = ('lk->s1ij',) * n_dm \ + ('li->s1kj',) * n_dm \ + ('jk->s1il',) * n_dm - vx = _vhf.rdirect_bindm(intor_prefix+'ssp1sps2_spinor', 's1', jks, dms[n_dm*2:], 1, - mol._atm, mol._bas, mol._env, mf_opt) + vx = _vhf.rdirect_bindm(intor_prefix+'ssp1sps2_spinor', 's1', jks, dms[n_dm:], 1, + mol._atm, mol._bas, mol._env, opt_gaunt_lssl) + + t0 = log.timer_debug1('LSSL contribution', *t0) vj[:,:n2c,n2c:]+= vx[ :n_dm ,:,:] vk[:,n2c:,n2c:] = vx[n_dm :n_dm*2,:,:] 
vk[:,:n2c,:n2c] = vx[n_dm*2: ,:,:] @@ -1045,8 +1091,8 @@ def set_dm(self, dm, atm, bas, env): (1, 0, (1, 1)), ]} mol.build() -############## -# SCF result + ############## + # SCF result method = UHF(mol) energy = method.scf() #-2.38146942868 print(energy) diff --git a/pyscf/scf/test/test_dhf.py b/pyscf/scf/test/test_dhf.py index cde37f4b0a..07ee57229f 100644 --- a/pyscf/scf/test/test_dhf.py +++ b/pyscf/scf/test/test_dhf.py @@ -238,6 +238,7 @@ def test_get_jk_with_gaunt_breit_high_cost(self): vj0 = numpy.einsum('ijkl,xlk->xij', eri1, dm) vk0 = numpy.einsum('ijkl,xjk->xil', eri1, dm) + mf = scf.dhf.DHF(h4) mf.with_breit = True vj1, vk1 = mf.get_jk(h4, dm, hermi=1) self.assertTrue(numpy.allclose(vj0, vj1))