Skip to content

Commit

Permalink
integral screening for gaunt and breit term (pyscf#2437)
Browse files Browse the repository at this point in the history
* integral screening for gaunt and breit term

* Revert to using the g_lssl integral itself to screen the Gaunt term.
Since the q_cond arrays for g_lssl and g_lsls are not symmetric, an asymmetric q_cond is added.

* The mf object is reinitialized at line 241; otherwise the VHFOpt previously initialized for Gaunt would be reused for Breit as well, which would give wrong results.

* trim whitespace
  • Loading branch information
xubwa authored Oct 19, 2024
1 parent c305a63 commit 0ada19a
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 28 deletions.
139 changes: 139 additions & 0 deletions pyscf/lib/vhf/rkb_screen.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#define SL 2
#define LS 3

#define GAUNT_LL 0
#define GAUNT_SS 1
#define GAUNT_LS 2 // in gaunt_lssl screening, put order to ll, ss, ls

int int2e_spinor();
int int2e_spsp1spsp2_spinor();
Expand Down Expand Up @@ -148,6 +151,61 @@ int CVHFrkbssll_vkscreen(int *shls, CVHFOpt *opt,
}


/*
 * Prescreening test for one shell quartet of the Gaunt/Breit LSLS term.
 *
 * shls holds the four shell indices (i, j, k, l).  The quartet estimate is
 * q_cond[i,j] * q_cond[k,l].  The quartet is kept (return 1) only when the
 * estimate exceeds direct_scf_cutoff AND at least one of dm_cond[k,l] or
 * dm_cond[j,k] exceeds direct_scf_cutoff / estimate.  Otherwise 0 is
 * returned.  A NULL opt disables screening (always 1).
 *
 * atm, bas and env are not used by this function.
 */
int CVHFrkb_gaunt_lsls_prescreen(int *shls, CVHFOpt *opt,
                                 int *atm, int *bas, double *env)
{
        if (opt == NULL) {
                return 1; // no screen
        }
        const int ish = shls[0];
        const int jsh = shls[1];
        const int ksh = shls[2];
        const int lsh = shls[3];
        const int nbas = opt->nbas;
        assert(opt->q_cond);
        assert(opt->dm_cond);
        assert(ish < nbas);
        assert(jsh < nbas);
        assert(ksh < nbas);
        assert(lsh < nbas);

        const double cutoff = opt->direct_scf_cutoff;
        const double qijkl = opt->q_cond[ish*nbas+jsh] * opt->q_cond[ksh*nbas+lsh];
        // Written as a negated '>' so that a NaN estimate is rejected.
        if (!(qijkl > cutoff)) {
                return 0;
        }

        // Smallest density estimate that can still lift the quartet
        // contribution above the cutoff.
        const double dmin = cutoff / qijkl;
        if (opt->dm_cond[ksh*nbas+lsh] > dmin) {
                return 1;
        }
        return opt->dm_cond[jsh*nbas+ksh] > dmin;
}


/*
 * Prescreening test for one shell quartet of the Gaunt/Breit LSSL term.
 *
 * shls holds the four shell indices (i, j, k, l).  The quartet estimate is
 * q_cond[i,j] * q_cond[k,l].  dm_cond is read as three consecutive
 * nbas*nbas blocks selected by GAUNT_LL, GAUNT_SS and GAUNT_LS.  The quartet
 * is kept (return 1) only when the estimate exceeds direct_scf_cutoff AND at
 * least one of
 *      block GAUNT_LL at [j,k],
 *      block GAUNT_SS at [l,i],
 *      block GAUNT_LS at [l,k]
 * exceeds direct_scf_cutoff / estimate.  A NULL opt disables screening.
 *
 * NOTE(review): earlier inline comments labelled the three tests as
 * dmss_ji, dmll_lk and dmls_lk, which does not match the blocks/indices
 * actually read, and the header of CVHFrkb_gaunt_lssl_dm_cond documents the
 * density order as (dmls, dmll, dmss).  Confirm which density block each
 * test is meant to look at.
 *
 * atm, bas and env are not used by this function.
 */
int CVHFrkb_gaunt_lssl_prescreen(int *shls, CVHFOpt *opt,
                                 int *atm, int *bas, double *env)
{
        if (opt == NULL) {
                return 1; // no screen
        }
        const int ish = shls[0];
        const int jsh = shls[1];
        const int ksh = shls[2];
        const int lsh = shls[3];
        const int nbas = opt->nbas;
        assert(opt->q_cond);
        assert(opt->dm_cond);
        assert(ish < nbas);
        assert(jsh < nbas);
        assert(ksh < nbas);
        assert(lsh < nbas);

        double *cond_ll = opt->dm_cond + nbas*nbas*GAUNT_LL;
        double *cond_ss = opt->dm_cond + nbas*nbas*GAUNT_SS;
        double *cond_ls = opt->dm_cond + nbas*nbas*GAUNT_LS;

        const double cutoff = opt->direct_scf_cutoff;
        const double qijkl = opt->q_cond[ish*nbas+jsh] * opt->q_cond[ksh*nbas+lsh];
        // Written as a negated '>' so that a NaN estimate is rejected.
        if (!(qijkl > cutoff)) {
                return 0;
        }

        const double dmin = cutoff / qijkl;
        if (cond_ll[jsh*nbas+ksh] > dmin) {
                return 1;
        }
        if (cond_ss[lsh*nbas+ish] > dmin) {
                return 1;
        }
        return cond_ls[lsh*nbas+ksh] > dmin;
}


void CVHFrkb_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
Expand Down Expand Up @@ -194,6 +252,67 @@ void CVHFrkb_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond,
}
}


/*
 * Evaluate the integral block for the shell quartet (ash, bsh, ash, bsh)
 * and return sqrt(max |(ab|ab)|) over its "diagonal" elements, i.e. the
 * elements whose 1st/3rd and 2nd/4th function indices coincide.
 * Returns 1e-100 (not square-rooted) when intor reports an all-zero block.
 * buf and cache are caller-provided scratch areas.
 */
static double asym_qmax_diag(int (*intor)(), CINTOpt *cintopt,
                             double complex *buf, double *cache,
                             int ash, int bsh, int *ao_loc,
                             int *atm, int natm, int *bas, int nbas,
                             double *env)
{
        int shls[4] = {ash, bsh, ash, bsh};
        const size_t da = ao_loc[ash+1] - ao_loc[ash];
        const size_t db = ao_loc[bsh+1] - ao_loc[bsh];
        size_t a, b;
        double tmp;
        double qmax = 1e-100;

        if (0 == (*intor)(buf, NULL, shls, atm, natm, bas, nbas, env, cintopt, cache)) {
                return qmax;
        }
        // The block has dimensions (da, db, da, db) with the first index
        // running fastest; pick element (a, b, a, b).
        for (a = 0; a < da; a++) {
        for (b = 0; b < db; b++) {
                tmp = cabs(buf[a + da*b + da*db*a + da*db*da*b]);
                qmax = MAX(qmax, tmp);
        } }
        return sqrt(qmax);
}

/*
 * Fill the full nbas x nbas array qcond with Schwarz-type estimates for an
 * integral that is NOT symmetric under exchange of the two shells in a
 * pair: qcond[ish,jsh] is derived from the quartet (ish,jsh,ish,jsh) and
 * qcond[jsh,ish] from (jsh,ish,jsh,ish), each evaluated separately.
 *
 * intor    integral function, called as
 *          intor(buf, NULL, shls, atm, natm, bas, nbas, env, cintopt, cache)
 * cintopt  optimizer handle forwarded to intor
 * qcond    output, nbas*nbas doubles, row-major
 * ao_loc   function offsets per shell (nbas+1 entries are read)
 */
void CVHFrkb_asym_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond,
                         int *ao_loc, int *atm, int natm,
                         int *bas, int nbas, double *env)
{
        int shls_slice[] = {0, nbas};
        const int cache_size = GTOmax_cache_size(intor, shls_slice, 1,
                                                 atm, natm, bas, nbas, env);
        // Largest shell dimension; loop-invariant, so computed once here
        // instead of once per thread.
        size_t dmax = 0;
        size_t dsh;
        int ksh;
        for (ksh = 0; ksh < nbas; ksh++) {
                dsh = ao_loc[ksh+1] - ao_loc[ksh];
                dmax = MAX(dmax, dsh);
        }

#pragma omp parallel
{
        int ij, ish, jsh;
        // Per-thread scratch space, sized for the largest possible quartet.
        double *cache = malloc(sizeof(double) * cache_size);
        double complex *buf = malloc(sizeof(double complex) * dmax*dmax*dmax*dmax);
#pragma omp for schedule(dynamic, 4)
        for (ij = 0; ij < nbas*(nbas+1)/2; ij++) {
                // Unpack the lower-triangular pair index ij -> (ish >= jsh).
                ish = (int)(sqrt(2*ij+.25) - .5 + 1e-7);
                jsh = ij - ish*(ish+1)/2;

                qcond[(size_t)ish*nbas+jsh] =
                        asym_qmax_diag(intor, cintopt, buf, cache, ish, jsh,
                                       ao_loc, atm, natm, bas, nbas, env);
                // For ish == jsh the swapped quartet is the same quartet and
                // targets the same qcond element, so skip the repeat.
                if (ish != jsh) {
                        qcond[(size_t)jsh*nbas+ish] =
                                asym_qmax_diag(intor, cintopt, buf, cache, jsh, ish,
                                               ao_loc, atm, natm, bas, nbas, env);
                }
        }
        free(buf);
        free(cache);
}
}


void CVHFrkbllll_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
Expand Down Expand Up @@ -331,6 +450,26 @@ void CVHFrkbssll_dm_cond(double *dm_cond, double complex *dm, int nset, int *ao_
}
}

/*
 * Build the density screening arrays for the Gaunt/Breit LSSL term.
 *
 * dm holds three equally sized groups of density matrices, nset/3 matrices
 * of n2c*n2c complex numbers per group.  The current order of the groups
 * (dmls, dmll, dmss) is consistent with the second contraction in function
 * _call_veff_gaunt_breit in dhf.py.  Each group is passed unchanged to
 * CVHFrkb_dm_cond, and the result for group g is written at
 * dm_cond + g * (1 + nset/3) * nbas*nbas.
 *
 * All three groups are treated identically, so the group order only matters
 * to the consumer of dm_cond.
 *
 * NOTE(review): the (1 + nset/3) factor presumably reserves one combined
 * array plus one array per density matrix, as laid out by CVHFrkb_dm_cond --
 * confirm.  CVHFrkb_gaunt_lssl_prescreen steps through dm_cond in units of
 * nbas*nbas per block, which differs from this stride whenever nset/3 > 0;
 * verify the two agree on the layout.
 *
 * Offsets are computed in size_t so that large nbas or n2c cannot overflow
 * int arithmetic.
 */
void CVHFrkb_gaunt_lssl_dm_cond(double *dm_cond, double complex *dm, int nset, int *ao_loc,
                                int *atm, int natm, int *bas, int nbas, double *env)
{
        const int nset_per_group = nset / 3;
        const size_t n2c = CINTtot_cgto_spinor(bas, nbas);
        const size_t cond_stride = (size_t)(1 + nset_per_group) * nbas * nbas;
        const size_t dm_stride = n2c * n2c * nset_per_group;
        size_t group;

        for (group = 0; group < 3; group++) {
                CVHFrkb_dm_cond(dm_cond + cond_stride * group,
                                dm + dm_stride * group,
                                nset_per_group, ao_loc, atm, natm, bas, nbas, env);
        }
}

// the current order of dmscond (dmll, dmss, dmsl) is consistent to the
// function _call_veff_ssll in dhf.py
void CVHFrkbssll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset,
Expand Down
102 changes: 74 additions & 28 deletions pyscf/scf/dhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def kernel(mf, conv_tol=1e-9, conv_tol_grad=None,
else:
dm = dm0

mf._coulomb_level = 'LLLL'
if mf.init_guess != 'chkfile':
mf._coulomb_level = 'LLLL'
cycles = 0
if dm0 is None and mf._coulomb_level.upper() == 'LLLL':
scf_conv, e_tot, mo_energy, mo_coeff, mo_occ \
Expand Down Expand Up @@ -145,6 +146,7 @@ def get_jk_coulomb(mol, dm, hermi=1, coulomb_allow='SSSS',
opt_llll=None, opt_ssll=None, opt_ssss=None,
omega=None, verbose=None):
log = logger.new_logger(mol, verbose)
t0 = (logger.process_clock(), logger.perf_counter())

if hermi == 0 and DEBUG:
# J matrix is symmetrized in this function which is only true for
Expand All @@ -155,29 +157,35 @@ def get_jk_coulomb(mol, dm, hermi=1, coulomb_allow='SSSS',
if coulomb_allow.upper() == 'LLLL':
log.debug('Coulomb integral: (LL|LL)')
j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll)
log.timer_debug1('LLLL', *t0)
n2c = j1.shape[1]
vj = numpy.zeros_like(dm)
vk = numpy.zeros_like(dm)
vj[...,:n2c,:n2c] = j1
vk[...,:n2c,:n2c] = k1
vj[..., :n2c, :n2c] = j1
vk[..., :n2c, :n2c] = k1
elif coulomb_allow.upper() == 'SSLL' \
or coulomb_allow.upper() == 'LLSS':
log.debug('Coulomb integral: (LL|LL) + (SS|LL)')
vj, vk = _call_veff_ssll(mol, dm, hermi, opt_ssll)
t0 = log.timer_debug1('SSLL', *t0)
j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll)
log.timer_debug1('LLLL', *t0)
n2c = j1.shape[1]
vj[...,:n2c,:n2c] += j1
vk[...,:n2c,:n2c] += k1
else: # coulomb_allow == 'SSSS'
vj[..., :n2c, :n2c] += j1
vk[..., :n2c, :n2c] += k1
else: # coulomb_allow == 'SSSS'
log.debug('Coulomb integral: (LL|LL) + (SS|LL) + (SS|SS)')
vj, vk = _call_veff_ssll(mol, dm, hermi, opt_ssll)
t0 = log.timer_debug1('SSLL', *t0)
j1, k1 = _call_veff_llll(mol, dm, hermi, opt_llll)
t0 = log.timer_debug1('LLLL', *t0)
n2c = j1.shape[1]
vj[...,:n2c,:n2c] += j1
vk[...,:n2c,:n2c] += k1
vj[..., :n2c, :n2c] += j1
vk[..., :n2c, :n2c] += k1
j1, k1 = _call_veff_ssss(mol, dm, hermi, opt_ssss)
vj[...,n2c:,n2c:] += j1
vk[...,n2c:,n2c:] += k1
log.timer_debug1('SSSS', *t0)
vj[..., n2c:, n2c:] += j1
vk[..., n2c:, n2c:] += k1

return vj, vk
get_jk = get_jk_coulomb
Expand Down Expand Up @@ -477,6 +485,7 @@ class DHF(hf.SCF):
with_breit = getattr(__config__, 'scf_dhf_SCF_with_breit', False)
# corrections for small component when with_ssss is set to False
ssss_approx = getattr(__config__, 'scf_dhf_SCF_ssss_approx', 'Visscher')
screening = True

_keys = {'conv_tol', 'with_ssss', 'with_gaunt', 'with_breit', 'ssss_approx'}

Expand Down Expand Up @@ -605,11 +614,34 @@ def set_vkscreen(opt, name):
direct_scf_tol=self.direct_scf_tol)
opt_ssll.q_cond = numpy.array([opt_llll.q_cond, opt_ssss.q_cond])
set_vkscreen(opt_ssll, 'CVHFrkbssll_vkscreen')
logger.timer(self, 'init_direct_scf_coulomb', *cpu0)

opt_gaunt_lsls = None
opt_gaunt_lssl = None

#TODO: prescreen for gaunt
opt_gaunt = None
logger.timer(self, 'init_direct_scf', *cpu0)
return opt_llll, opt_ssll, opt_ssss, opt_gaunt
if self.with_gaunt:
if self.with_breit:
# integral function int2e_breit_ssp1ssp2_spinor evaluates
# -1/2[alpha1*alpha2/r12 + (alpha1*r12)(alpha2*r12)/r12^3]
intor_prefix = 'int2e_breit_'
else:
# integral function int2e_ssp1ssp2_spinor evaluates only
# alpha1*alpha2/r12. Minus sign was not included.
intor_prefix = 'int2e_'
opt_gaunt_lsls = _VHFOpt(mol, intor_prefix + 'ssp1ssp2_spinor',
'CVHFrkb_gaunt_lsls_prescreen', 'CVHFrkb_asym_q_cond',
'CVHFrkb_dm_cond',
direct_scf_tol=self.direct_scf_tol/c1**2)

opt_gaunt_lssl = _VHFOpt(mol, intor_prefix + 'ssp1sps2_spinor',
'CVHFrkb_gaunt_lssl_prescreen', 'CVHFrkb_asym_q_cond',
'CVHFrkb_dm_cond',
direct_scf_tol=self.direct_scf_tol/c1**2)

logger.timer(self, 'init_direct_scf_gaunt_breit', *cpu0)
#return None, None, None, None, None
return opt_llll, opt_ssll, opt_ssss, opt_gaunt_lsls, opt_gaunt_lssl

def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,
omega=None):
Expand All @@ -622,13 +654,16 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,
self._opt[omega] = self.init_direct_scf(mol)
vhfopt = self._opt.get(omega)
if vhfopt is None:
opt_llll = opt_ssll = opt_ssss = opt_gaunt = None
opt_llll = opt_ssll = opt_ssss = opt_gaunt_lsls = opt_gaunt_lssl = None
else:
opt_llll, opt_ssll, opt_ssss, opt_gaunt = vhfopt
opt_llll, opt_ssll, opt_ssss, opt_gaunt_lsls, opt_gaunt_lssl = vhfopt
if self.screening is False:
opt_llll = opt_ssll = opt_ssss = opt_gaunt_lsls = opt_gaunt_lssl = None

opt_gaunt = (opt_gaunt_lsls, opt_gaunt_lssl)
vj, vk = get_jk_coulomb(mol, dm, hermi, self._coulomb_level,
opt_llll, opt_ssll, opt_ssss, omega, log)

t1 = log.timer_debug1('Coulomb', *t0)
if self.with_breit:
assert omega is None
if ('SSSS' in self._coulomb_level.upper() or
Expand All @@ -644,6 +679,7 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,
vj1, vk1 = _call_veff_gaunt_breit(mol, dm, hermi, opt_gaunt, False)
vj += vj1
vk += vk1
log.timer_debug1('Gaunt and Breit term', *t1)

log.timer('vj and vk', *t0)
return vj, vk
Expand Down Expand Up @@ -927,6 +963,14 @@ def _call_veff_ssss(mol, dm, hermi=1, mf_opt=None):
return _jk_triu_(mol, vj, vk, hermi)

def _call_veff_gaunt_breit(mol, dm, hermi=1, mf_opt=None, with_breit=False):
if mf_opt is not None:
opt_gaunt_lsls, opt_gaunt_lssl = mf_opt
else:
opt_gaunt_lsls = opt_gaunt_lssl = None

log = logger.new_logger(mol)
t0 = (logger.process_clock(), logger.perf_counter())

if with_breit:
# integral function int2e_breit_ssp1ssp2_spinor evaluates
# -1/2[alpha1*alpha2/r12 + (alpha1*r12)(alpha2*r12)/r12^3]
Expand All @@ -946,30 +990,32 @@ def _call_veff_gaunt_breit(mol, dm, hermi=1, mf_opt=None, with_breit=False):
dmsl = dm[n2c:,:n2c].copy()
dmll = dm[:n2c,:n2c].copy()
dmss = dm[n2c:,n2c:].copy()
dms = [dmsl, dmsl, dmls, dmll, dmss]
dms = [dmsl, dmls, dmll, dmss]
else:
n_dm = len(dm)
n2c = dm[0].shape[0] // 2
dmll = [dmi[:n2c,:n2c].copy() for dmi in dm]
dmls = [dmi[:n2c,n2c:].copy() for dmi in dm]
dmsl = [dmi[n2c:,:n2c].copy() for dmi in dm]
dmss = [dmi[n2c:,n2c:].copy() for dmi in dm]
dms = dmsl + dmsl + dmls + dmll + dmss
dms = dmsl + dmls + dmll + dmss
vj = numpy.zeros((n_dm,n2c*2,n2c*2), dtype=numpy.complex128)
vk = numpy.zeros((n_dm,n2c*2,n2c*2), dtype=numpy.complex128)

jks = ('lk->s1ij',) * n_dm \
+ ('jk->s1il',) * n_dm
vx = _vhf.rdirect_bindm(intor_prefix+'ssp1ssp2_spinor', 's1', jks, dms[:n_dm*2], 1,
mol._atm, mol._bas, mol._env, mf_opt)
vj[:,:n2c,n2c:] = vx[:n_dm,:,:]
vk[:,:n2c,n2c:] = vx[n_dm:,:,:]
jks = ('lk->s1ij', 'jk->s1il')
vj_ls, vk_ls = _vhf.rdirect_mapdm(intor_prefix+'ssp1ssp2_spinor', 's1', jks, dms[:n_dm], 1,
mol._atm, mol._bas, mol._env, opt_gaunt_lsls)
vj[:,:n2c,n2c:] = vj_ls
vk[:,:n2c,n2c:] = vk_ls
t0 = log.timer_debug1('LSLS contribution', *t0)

jks = ('lk->s1ij',) * n_dm \
+ ('li->s1kj',) * n_dm \
+ ('jk->s1il',) * n_dm
vx = _vhf.rdirect_bindm(intor_prefix+'ssp1sps2_spinor', 's1', jks, dms[n_dm*2:], 1,
mol._atm, mol._bas, mol._env, mf_opt)
vx = _vhf.rdirect_bindm(intor_prefix+'ssp1sps2_spinor', 's1', jks, dms[n_dm:], 1,
mol._atm, mol._bas, mol._env, opt_gaunt_lssl)

t0 = log.timer_debug1('LSSL contribution', *t0)
vj[:,:n2c,n2c:]+= vx[ :n_dm ,:,:]
vk[:,n2c:,n2c:] = vx[n_dm :n_dm*2,:,:]
vk[:,:n2c,:n2c] = vx[n_dm*2: ,:,:]
Expand Down Expand Up @@ -1045,8 +1091,8 @@ def set_dm(self, dm, atm, bas, env):
(1, 0, (1, 1)), ]}
mol.build()

##############
# SCF result
##############
# SCF result
method = UHF(mol)
energy = method.scf() #-2.38146942868
print(energy)
Expand Down
1 change: 1 addition & 0 deletions pyscf/scf/test/test_dhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def test_get_jk_with_gaunt_breit_high_cost(self):
vj0 = numpy.einsum('ijkl,xlk->xij', eri1, dm)
vk0 = numpy.einsum('ijkl,xjk->xil', eri1, dm)

mf = scf.dhf.DHF(h4)
mf.with_breit = True
vj1, vk1 = mf.get_jk(h4, dm, hermi=1)
self.assertTrue(numpy.allclose(vj0, vj1))
Expand Down

0 comments on commit 0ada19a

Please sign in to comment.