Skip to content

Commit

Permalink
Adjust test FP precision
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Sep 2, 2016
1 parent b8843be commit d081a6e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
30 changes: 21 additions & 9 deletions testsuite/test_3c2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@
_cint.CINTlen_spinor.restype = ctypes.c_int


def close(v1, vref, count, place):
return round(abs(v1-vref)/count, place) == 0

def test_int3c2e_sph(name, fnref, vref, dim, place):
intor = getattr(_cint, name)
Expand All @@ -178,6 +180,7 @@ def test_int3c2e_sph(name, fnref, vref, dim, place):
opref = numpy.empty(1000000*dim)
pref = opref.ctypes.data_as(ctypes.c_void_p)
v1 = 0
cnt = 0
for k in range(nbas.value):
l = nfitid
bas[l,ATOM_OF] = bas[k,ATOM_OF]
Expand All @@ -193,10 +196,11 @@ def test_int3c2e_sph(name, fnref, vref, dim, place):
if not numpy.allclose(opref[:nd], op[:nd]):
print 'Fail:', name, i,j,k
v1 += abs(numpy.array(op[:nd])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += nd
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref


def sf2spinor(mat, i, j, bas):
Expand Down Expand Up @@ -230,6 +234,7 @@ def test_int3c2e_spinor(name, fnref, vref, dim, place):
intoref = getattr(_cint, fnref)
intor.restype = ctypes.c_void_p
v1 = 0
cnt = 0
for k in range(nbas.value):
l = nfitid
for j in range(nbas.value):
Expand All @@ -252,10 +257,11 @@ def test_int3c2e_spinor(name, fnref, vref, dim, place):
if not numpy.allclose(zmat, op[:,:,:,0]):
print 'Fail:', name, i,j,k
v1 += abs(numpy.array(op)).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += op.size
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref


def test_int2c_sph(name, fnref, vref, dim, place):
Expand All @@ -267,6 +273,7 @@ def test_int2c_sph(name, fnref, vref, dim, place):
opref = numpy.empty(1000000*dim)
pref = opref.ctypes.data_as(ctypes.c_void_p)
v1 = 0
cnt = 0
for k in range(nbas.value):
for i in range(nbas.value):
j = nfitid1
Expand All @@ -281,14 +288,19 @@ def test_int2c_sph(name, fnref, vref, dim, place):
if not numpy.allclose(opref[:nd], op[:nd]):
print 'Fail:', name, i,k
v1 += abs(numpy.array(op[:nd])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += nd
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref



if __name__ == "__main__":
if "--high-prec" in sys.argv:
def close(v1, vref, count, place):
return round(abs(v1-vref), place) == 0

for f in (('cint3c2e_sph', 'cint2e_sph', 1586.350797432699, 1, 10),
('cint3c2e_ip1_sph', 'cint2e_ip1_sph', 2242.052249267909, 3, 10),
('cint3c2e_ip2_sph', 'cint2e_ip2_sph', 1970.982483860059, 3, 10),
Expand Down
46 changes: 30 additions & 16 deletions testsuite/test_cint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,27 @@
_cint.CINTlen_spinor.restype = ctypes.c_int


def close(v1, vref, count, place):
return round(abs(v1-vref)/count, place) == 0

def test_int1e_sph(name, vref, dim, place):
intor = getattr(_cint, name)
intor.restype = ctypes.c_void_p
op = (ctypes.c_double * (10000 * dim))()
v1 = 0
cnt = 0
for j in range(nbas.value*2):
for i in range(j+1):
di = (bas[i,ANG_OF] * 2 + 1) * bas[i,NCTR_OF]
dj = (bas[j,ANG_OF] * 2 + 1) * bas[j,NCTR_OF]
shls = (ctypes.c_int * 2)(i, j)
intor(op, shls, c_atm, natm, c_bas, nbas, c_env);
v1 += abs(numpy.array(op[:di*dj*dim])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += di*dj*dim
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref

def cdouble_to_cmplx(arr):
return numpy.array(arr)[0::2] + numpy.array(arr)[1::2] * 1j
Expand All @@ -168,17 +172,19 @@ def test_int1e_spinor(name, vref, dim, place):
intor.restype = ctypes.c_void_p
op = (ctypes.c_double * (20000 * dim))()
v1 = 0
cnt = 0
for j in range(nbas.value*2):
for i in range(j+1):
di = _cint.CINTlen_spinor(i, c_bas, nbas) * bas[i,NCTR_OF]
dj = _cint.CINTlen_spinor(j, c_bas, nbas) * bas[j,NCTR_OF]
shls = (ctypes.c_int * 2)(i, j)
intor(op, shls, c_atm, natm, c_bas, nbas, c_env);
v1 += abs(cdouble_to_cmplx(op[:di*dj*dim*2])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += di*dj*dim*2
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref

def max_loc(arr):
loc = []
Expand Down Expand Up @@ -228,6 +234,7 @@ def test_int2e_sph(name, vref, dim, place):
intor.restype = ctypes.c_void_p
op = (ctypes.c_double * (1000000 * dim))()
v1 = 0
cnt = 0
for l in range(nbas.value*2):
for k in range(l+1):
for j in range(nbas.value*2):
Expand All @@ -239,16 +246,18 @@ def test_int2e_sph(name, vref, dim, place):
shls = (ctypes.c_int * 4)(i, j, k, l)
intor(op, shls, c_atm, natm, c_bas, nbas, c_env, opt);
v1 += abs(numpy.array(op[:di*dj*dk*dl*dim])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += di*dj*dk*dl*dim
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref

def test_int2e_spinor(name, vref, dim, place):
intor = getattr(_cint, name)
intor.restype = ctypes.c_void_p
op = (ctypes.c_double * (2000000 * dim))()
v1 = 0
cnt = 0
for l in range(nbas.value*2):
for k in range(l+1):
for j in range(nbas.value*2):
Expand All @@ -260,10 +269,11 @@ def test_int2e_spinor(name, vref, dim, place):
shls = (ctypes.c_int * 4)(i, j, k, l)
intor(op, shls, c_atm, natm, c_bas, nbas, c_env, opt);
v1 += abs(cdouble_to_cmplx(op[:di*dj*dk*dl*dim*2])).sum()
if round(abs(v1-vref), place):
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref
else:
cnt += di*dj*dk*dl*dim*2
if close(v1, vref, cnt, place):
print "pass: ", name
else:
print "* FAIL: ", name, ". err:", '%.16g' % abs(v1-vref), "/", vref

def test_comp2e_spinor(name1, name_ref, shift, dim, place):
intor = getattr(_cint, name1)
Expand Down Expand Up @@ -308,6 +318,10 @@ def test_comp2e_spinor(name1, name_ref, shift, dim, place):


if __name__ == "__main__":
if "--high-prec" in sys.argv:
def close(v1, vref, count, place):
return round(abs(v1-vref), place) == 0

for f in (('cint1e_ovlp_sph' , 320.9470780962389, 1, 11),
('cint1e_nuc_sph' , 3664.898206036863, 1, 10),
('cint1e_kin_sph' , 887.2525599069498, 1, 11),
Expand Down Expand Up @@ -362,14 +376,14 @@ def test_comp2e_spinor(name1, name_ref, shift, dim, place):
):
test_int1e_spinor(*f)

for f in (('cint2e_sph' , 56243.88328768107 , 1, 9 ),
for f in (('cint2e_sph' , 56243.88328768107 , 1, 8 ),
('cint2e_ig1_sph', 8101.087334398195 , 3, 10),
('cint2e_ip1_sph', 115489.8643866550 , 3, 8 ),
('cint2e_p1vxp1_sph', 89014.88169743448, 3, 9),
):
test_int2e_sph(*f)
if "--quick" not in sys.argv:
for f in (('cint2e' , 37737.11365710611, 1, 9),
for f in (('cint2e' , 37737.11365710611, 1, 8),
('cint2e_spsp1' , 221528.4764668166, 1, 8),
('cint2e_spsp1spsp2' , 1391716.876869147, 1, 7),
('cint2e_srsr1' , 178572.7398308939, 1, 8),
Expand All @@ -388,7 +402,7 @@ def test_comp2e_spinor(name1, name_ref, shift, dim, place):
('cint2e_ipspsp1spsp2', 1443972.936563201, 3, 7 ),
):
test_int2e_spinor(*f)

test_comp2e_spinor('cint2e_spsp1', 'cint2e', (4,4,0,0), 1, 11)
test_comp2e_spinor('cint2e_spsp1spsp2', 'cint2e', (4,4,4,4), 1, 11)
test_comp2e_spinor('cint2e_spsp1spsp2', 'cint2e_spsp1', (0,0,4,4), 1, 11)
Expand Down Expand Up @@ -423,5 +437,5 @@ def test_comp2e_spinor(name1, name_ref, shift, dim, place):
v1 = abs(opz-opr[:,:,2]).sum()
v1 += abs(opzz-oprr[:,:,8]).sum()
v1 += abs(opr2-oprr[:,:,0]-oprr[:,:,4]-oprr[:,:,8]).sum()
if round(v1, 13):
if round(v1/(di*dj), 13):
print "* FAIL: ", i, j, v1

0 comments on commit d081a6e

Please sign in to comment.