Skip to content

Commit

Permalink
Merge branch 'dev' into issue_54
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Mar 3, 2025
2 parents 3e01567 + 712b005 commit bcd7cb9
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 32 deletions.
1 change: 1 addition & 0 deletions my_pyscf/lassi/op_o1/hci.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=Non
# Handling for 1s1c: need to do both a'.sm.b and b'.sp.a explicitly
all_interactions_full_square = True
interaction_has_spin = ('_1c_', '_1c1d_', '_1s1c_', '_2c_')
ltri_ambiguous = False

def _init_vecs (self):
hci_fr_pabq = []
Expand Down
17 changes: 6 additions & 11 deletions my_pyscf/lassi/op_o1/hsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=Non
HamS2Ovlp.__init__(self, ints, nlas, hopping_index, lroots, h1, h2,
mask_bra_space=mask_bra_space, mask_ket_space=mask_ket_space,
log=log, max_memory=max_memory, dtype=dtype)
self.urootstr = np.asarray ([[i.unique_root[r] for i in self.ints]
for r in range (self.nroots)]).T
self.x = self.si = np.zeros (self.nstates, self.dtype)
self.ox = np.zeros (self.nstates, self.dtype)
self.ox1 = np.zeros (self.nstates, self.dtype)
Expand Down Expand Up @@ -202,12 +200,12 @@ def _opuniq_x_group_(self, inv, group):
t0, w0 = logger.process_clock (), logger.perf_counter ()
kets = np.unique (ovlplink[:,0])
ketvecs = {ket: self.get_xvec (ket, *inv) for ket in set (kets)}
ovecs = {tuple(row):0 for row in ovlplink}
ovecs = {tuple(row[1:]):0 for row in ovlplink}
for row in ovlplink:
xvec = ketvecs[row[0]]
midstr = row[1:]
ketstr = self.urootstr[:,row[0]]
ovecs[tuple(row)] += self.ox_ovlp_part (midstr, ketstr, xvec, inv)
ovecs[tuple(midstr)] += self.ox_ovlp_part (midstr, ketstr, xvec, inv)
t1, w1 = logger.process_clock (), logger.perf_counter ()
self.dt_sX += (t1-t0)
self.dw_sX += (w1-w0)
Expand All @@ -228,22 +226,19 @@ def _opuniq_x_group_(self, inv, group):
self.dt_pX += (t3-t2)
self.dw_pX += (w3-w2)

def _opuniq_x_(self, op, bra, ket, ovecs, *inv):
def _opuniq_x_(self, op, obra, oket, ovecs, *inv):
'''All operations which are unique in that a given set of fragment bra statelets are
coupled to a given set of fragment ket statelets'''
key = tuple ((bra, ket)) + inv
key = tuple ((obra, oket)) + inv
inv = list (set (inv))
brakets, bras, braHs = self.get_nonuniq_exc_square (key)
bravecs = {bra: 0.0 for bra in bras+braHs}
for bra, ket in brakets:
bravecs[bra] += ovecs[tuple ((ket,)) + self.ox_ovlp_urootstr (bra, ket, inv)]
for bra in bras:
vec = bravecs[bra]
vec = ovecs[self.ox_ovlp_urootstr (bra, oket, inv)]
self.put_ox1_(lib.dot (op, vec.T).ravel (), bra, *inv)
if len (braHs):
op = op.conj ().T
for bra in braHs:
vec = bravecs[bra]
vec = ovecs[self.ox_ovlp_urootstr (bra, obra, inv)]
self.put_ox1_(lib.dot (op, vec.T).ravel (), bra, *inv)
return

Expand Down
33 changes: 24 additions & 9 deletions my_pyscf/lassi/op_o1/stdm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pyscf import lib
from pyscf.lib import logger
from itertools import product
from itertools import product, combinations
from mrh.my_pyscf.lassi.citools import get_rootaddr_fragaddr, umat_dot_1frag_
from mrh.my_pyscf.lassi.op_o1 import frag
from mrh.my_pyscf.lassi.op_o1.utilities import *
Expand Down Expand Up @@ -107,6 +107,9 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_
for nelec_sf in self.nelec_rf]
self.nelec_rf = self.nelec_rf.sum (1)

self.urootstr = np.asarray ([[i.unique_root[r] for i in self.ints]
for r in range (self.nroots)]).T

exc = self.make_exc_tables (hopping_index)
self.nonuniq_exc = {}
self.exc_null = self.mask_exc_table_(exc['null'], 'null', mask_bra_space, mask_ket_space)
Expand Down Expand Up @@ -136,6 +139,14 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_
else:
raise NotImplementedError (self.dtype)

def interaction_fprint (self, bra, ket, frags, ltri=False):
frags = np.sort (frags)
brastr = self.urootstr[frags,bra]
ketstr = self.urootstr[frags,ket]
if ltri: brastr, ketstr = sorted ([list(brastr),list(ketstr)])
fprint = np.stack ([frags, brastr, ketstr], axis=0)
return fprint

def init_profiling (self):
self.dt_1d, self.dw_1d = 0.0, 0.0
self.dt_2d, self.dw_2d = 0.0, 0.0
Expand Down Expand Up @@ -326,6 +337,7 @@ def make_exc_tables (self, hopping_index):

all_interactions_full_square = False
interaction_has_spin = ('_1c_', '_1c1d_', '_2c_')
ltri_ambiguous = True

def mask_exc_table_(self, exc, lbl, mask_bra_space=None, mask_ket_space=None):
# Part 1: restrict to the caller-specified rectangle
Expand All @@ -339,28 +351,31 @@ def mask_exc_table_(self, exc, lbl, mask_bra_space=None, mask_ket_space=None):
if lbl=='null': return exc
ulblu = '_' + lbl + '_'
excp = exc[:,:-1] if ulblu in self.interaction_has_spin else exc
fprintLT = []
fprint = []
for row in excp:
frow = []
bra, ket = row[:2]
frags = row[2:]
for frag in frags:
intf = self.ints[frag]
frow.extend ([frag, intf.unique_root[bra], intf.unique_root[ket]])
fprint.append (frow)
fpLT = self.interaction_fprint (bra, ket, frags, ltri=self.ltri_ambiguous)
fprintLT.append (fpLT.ravel ())
fp = self.interaction_fprint (bra, ket, frags, ltri=False)
fprint.append (fp.ravel ())
fprintLT = np.asarray (fprintLT)
fprint = np.asarray (fprint)
nexc = len (exc)
_, idx, inv = np.unique (fprint, axis=0, return_index=True, return_inverse=True)
fprintLT, idx, inv = np.unique (fprintLT, axis=0, return_index=True, return_inverse=True)
# for some reason this squeeze is necessary for some versions of numpy; however...
eqmap = np.squeeze (idx[inv])
for uniq_idx in idx:
for fpLT, uniq_idx in zip (fprintLT, idx):
row_uniq = excp[uniq_idx]
# ...numpy.where (0==0) triggers a DeprecationWarning, so I have to atleast_1d it
uniq_idxs = np.where (np.atleast_1d (eqmap==uniq_idx))[0]
braket_images = exc[np.ix_(uniq_idxs,[0,1])]
iT = np.any (fprint[uniq_idx][None,:]!=fprint[uniq_idxs], axis=1)
braket_images[iT,:] = braket_images[iT,::-1]
self.nonuniq_exc[tuple(row_uniq)] = braket_images
exc = exc[idx]
nuniq = len (idx)
nuniq = len (exc)
self.log.debug ('%d/%d unique interactions of %s type',
nuniq, nexc, lbl)
return exc
Expand Down
2 changes: 1 addition & 1 deletion pyscf-forge_version.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
git+https://github.com/pyscf/pyscf-forge.git@b7454c8a8c8d2ad13a03e93dc70b538e38d1bf8b
git+https://github.com/pyscf/pyscf-forge.git@226601b79600961a89b034f2c541861901200b71

2 changes: 1 addition & 1 deletion pyscf_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/pyscf/pyscf.git@afdd09fcf08c32a6de05f4ec827e8d1602531707
git+https://github.com/pyscf/pyscf.git@bee0ce288a655105e27fcb0293b203939b7aecc9
22 changes: 12 additions & 10 deletions tests/lassi/test_4frag.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,12 @@ def test_lassis_1111 (self):
lsi.prepare_states_()
h0, h1, h2 = ham_2q (las1, las1.mo_coeff)
case_contract_op_si (self, las1, h1, h2, lsi.ci, lsi.get_nelec_frs ())
lsi.kernel ()
self.assertTrue (lsi.converged)
self.assertAlmostEqual (lsi.e_roots[0], -1.867291372401379, 6)
case_lassis_fbf_2_model_state (self, lsi)
case_lassis_fbfdm (self, lsi)
else:
lsi.kernel ()
self.assertTrue (lsi.converged)
self.assertAlmostEqual (lsi.e_roots[0], -1.867291372401379, 6)
case_lassis_fbf_2_model_state (self, lsi)
case_lassis_fbfdm (self, lsi)

def test_lassis_slow (self):
las0 = las.get_single_state_las (state=0)
Expand All @@ -192,11 +193,12 @@ def test_lassis_slow (self):
lsi.prepare_states_()
h0, h1, h2 = ham_2q (las0, las0.mo_coeff)
case_contract_op_si (self, las, h1, h2, lsi.ci, lsi.get_nelec_frs ())
lsi.kernel ()
self.assertTrue (lsi.converged)
self.assertAlmostEqual (lsi.e_roots[0], -304.5372586630968, 3)
case_lassis_fbf_2_model_state (self, lsi)
#case_lassis_fbfdm (self, lsi)
else:
lsi.kernel ()
self.assertTrue (lsi.converged)
self.assertAlmostEqual (lsi.e_roots[0], -304.5372586630968, 3)
case_lassis_fbf_2_model_state (self, lsi)
#case_lassis_fbfdm (self, lsi)

def test_fdm1 (self):
make_fdm1 = get_fdm1_maker (las, las.ci, nelec_frs, si)
Expand Down

0 comments on commit bcd7cb9

Please sign in to comment.