Skip to content

Commit

Permalink
lassi hsi factorize
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Feb 21, 2025
1 parent 202d5d0 commit 09f6e67
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 53 deletions.
8 changes: 4 additions & 4 deletions debug/lassi/debug_22.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
from mrh.my_pyscf.lassi.op_o1 import get_fdm1_maker
from mrh.my_pyscf.lassi.sitools import make_sdm1
from mrh.tests.lassi.addons import case_contract_hlas_ci, case_lassis_fbf_2_model_state
from mrh.tests.lassi.addons import case_lassis_fbfdm, case_contract_op_si
from mrh.tests.lassi.addons import debug_contract_op_si
from mrh.tests.lassi.addons import case_lassis_fbfdm, case_contract_op_si, debug_contract_op_si

def setUpModule ():
global mol, mf, lsi, las, mc, op
Expand Down Expand Up @@ -128,8 +127,9 @@ class KnownValues(unittest.TestCase):
def test_contract_op_si (self):
e_roots, si, las = lsi.e_roots, lsi.si, lsi._las
h0, h1, h2 = lsi.ham_2q ()
h1[:] = h2[:2,:2,:2,:2] = h2[2:,2:,2:,2:] = 0.0
#h2[:] = 0.0
h2[:] = 0
#h2[:2,:2,:2,:2] = 0
#h2[2:,2:,2:,2:] = 0
debug_contract_op_si (self, las, h1, h2, las.ci, lsi.get_nelec_frs ())

#def test_lassis (self):
Expand Down
120 changes: 71 additions & 49 deletions my_pyscf/lassi/op_o1/hsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=Non
HamS2Ovlp.__init__(self, ints, nlas, hopping_index, lroots, h1, h2,
mask_bra_space=mask_bra_space, mask_ket_space=mask_ket_space,
log=log, max_memory=max_memory, dtype=dtype)
self.rootsigs = np.asarray ([[i.unique_root[r] for i in self.ints]
for r in range (self.nroots)])
self.urootstr = np.asarray ([[i.unique_root[r] for i in self.ints]
for r in range (self.nroots)]).T
self.x = self.si = np.zeros (self.nstates, self.dtype)
self.ox = np.zeros (self.nstates, self.dtype)
self.ox1 = np.zeros (self.nstates, self.dtype)
Expand All @@ -49,6 +49,8 @@ def _cache_operatorpart_(self):
(self._crunch_1c_, self._crunch_1c1d_, self._crunch_1s1c_,
self._crunch_2c_)):
self._crunch_oppart_(exc, fn, has_s=False)
self.excgroups_s = self._prepare_urootstr (self.excgroups_s)
self.excgroups_h = self._prepare_urootstr (self.excgroups_h)
self.log.debug1 (self.sprint_cache_profile ())
self.log.timer_debug1 ('HamS2OvlpOperators operator cacheing', *t0)

Expand Down Expand Up @@ -86,6 +88,30 @@ def _crunch_oppart_(self, exc, fn, has_s=False):
val.append ([op, bra, ket, row])
self.excgroups_s[key] = val

def _prepare_urootstr (self, groups):
for inv, group in groups.items ():
tab = np.zeros ((0,2), dtype=int)
for op, bra, ket, myinv in group:
key = tuple ((bra, ket)) + tuple (myinv)
tab = np.append (tab, self.get_nonuniq_exc_square (key)[0], axis=0)
tab = np.unique (tab, axis=0)
ovlplinkstr = [[ket,] + list (self.ox_ovlp_urootstr (bra, ket, inv)) for bra, ket in tab]
ovlplinkstr = np.unique (ovlplinkstr, axis=0)
groups[inv] = (group, np.asarray (ovlplinkstr))
return groups

def get_nonuniq_exc_square (self, key):
tab = self.nonuniq_exc[key]
n0 = len (tab)
tab = np.append (tab, tab[:,::-1], axis=0)
_, idx_uniq = np.unique (tab, return_index=True, axis=0)
tab = np.append (tab, np.ones ((len(tab),1),dtype=int), axis=1)
tab[:n0,2] = 0
brakets = tab[idx_uniq,:2]
_, idx_uniq = np.unique (tab[:,0], return_index=True, axis=0)
braopids = tab[:,[0,2]][idx_uniq,:]
return brakets, braopids

def init_cache_profiling (self):
self.dt_1d, self.dw_1d = 0.0, 0.0
self.dt_2d, self.dw_2d = 0.0, 0.0
Expand Down Expand Up @@ -173,70 +199,66 @@ def _s2_op (self, x):

def _opuniq_x_group_(self, inv, group):
'''All unique operations which have a set of nonspectator fragments in common'''
self.ox1[:] = 0
for op, bra, ket, myinv in group:
self._opuniq_x_(op, bra, ket, *myinv)
oplink, ovlplink = group

t0, w0 = logger.process_clock (), logger.perf_counter ()
kets = np.unique (ovlplink[:,0])
ketvecs = {ket: self.get_xvec (ket, *inv) for ket in set (kets)}
ovecs = {tuple(row):0 for row in ovlplink}
for row in ovlplink:
xvec = ketvecs[row[0]]
midstr = row[1:]
ketstr = self.urootstr[:,row[0]]
ovecs[tuple(row)] += self.ox_ovlp_part (midstr, ketstr, xvec, inv)
t1, w1 = logger.process_clock (), logger.perf_counter ()
self.dt_sX += (t1-t0)
self.dw_sX += (w1-w0)

self.ox1[:] = 0
for op, bra, ket, myinv in oplink:
self._opuniq_x_(op, bra, ket, ovecs, *myinv)
t2, w2 = logger.process_clock (), logger.perf_counter ()
self.dt_oX += (t2-t1)
self.dw_oX += (w2-w1)

for bra in range (self.nroots):
i, j = self.offs_lroots[bra]
self.ox[i:j] += transpose_sivec_with_slow_fragments (
self.ox1[i:j], self.lroots[:,bra], *inv
)[0]
t1, w1 = logger.process_clock (), logger.perf_counter ()
self.dt_pX += (t1-t0)
self.dw_pX += (w1-w0)
t3, w3 = logger.process_clock (), logger.perf_counter ()
self.dt_pX += (t3-t2)
self.dw_pX += (w3-w2)

def _opuniq_x_(self, op, bra, ket, *inv):
def _opuniq_x_(self, op, bra, ket, ovecs, *inv):
'''All operations which are unique in that a given set of fragment bra statelets are
coupled to a given set of fragment ket statelets'''
key = tuple ((bra, ket)) + inv
inv = list (set (inv))
tab = self.nonuniq_exc[key]
bras, kets = self.nonuniq_exc[key].T
self._op_x_(bras, kets, op, inv)
idx = bras==kets
bras = bras[~idx]
kets = kets[~idx]
self._op_x_(kets, bras, op.conj ().T, inv)
brakets, braopids = self.get_nonuniq_exc_square (key)
bravecs = {bra: 0.0 for bra in braopids[:,0]}
for bra, ket in brakets:
bravecs[bra] += ovecs[tuple ((ket,)) + self.ox_ovlp_urootstr (bra, ket, inv)]
op = (op, op.conj ().T)
for bra, opid in braopids:
vec = bravecs[bra]
self.put_oxvec_(lib.dot (op[opid], vec.T).ravel (), bra, *inv)
return

def _op_x_(self, bras, kets, op, inv):
t0, w0 = logger.process_clock (), logger.perf_counter ()

ketvecs = {ket: self.get_xvec (ket, *inv) for ket in set (kets)}

t1, w1 = logger.process_clock (), logger.perf_counter ()
self.dt_gX += (t1-t0)
self.dw_gX += (w1-w0)

bravecs = {bra: 0.0 for bra in set (bras)}
for bra, ket in zip (bras, kets):
bravecs[bra] += self.ox_ovlp_part (bra, ket, ketvecs[ket], inv)

t2, w2 = logger.process_clock (), logger.perf_counter ()
self.dt_sX += (t2-t1)
self.dw_sX += (w2-w1)

bravecs = {bra: lib.dot (op, vec.T) for bra, vec in bravecs.items ()}

t3, w3 = logger.process_clock (), logger.perf_counter ()
self.dt_oX += (t3-t2)
self.dw_oX += (w3-w2)

for bra, vec in bravecs.items ():
self.put_oxvec_(vec.ravel (), bra, *inv)

t4, w4 = logger.process_clock (), logger.perf_counter ()
self.dt_pX += (t4-t3)
self.dw_pX += (w4-w3)
return
def ox_ovlp_urootstr (self, bra, ket, inv):
'''Find the urootstr corresponding to the action of the overlap part of an operator
from ket to bra, which might or might not be a part of the model space.'''
urootstr = self.urootstr[:,bra].copy ()
inv = list (inv)
urootstr[inv] = self.urootstr[inv,ket]
return tuple (urootstr)

def ox_ovlp_part (self, bra, ket, vec, inv):
#if (bra==ket): vec = vec * 0.5
def ox_ovlp_part (self, brastr, ketstr, vec, inv):
spec = np.ones (self.nfrags, dtype=bool)
spec[inv] = False
spec[list(inv)] = False
lr = 1
for i in range (self.nfrags):
for i, (bra, ket) in enumerate (zip (brastr, ketstr)):
lket = self.lroots[i,ket]
if spec[i]:
vec = vec.reshape (-1,lket,lr)
Expand Down

0 comments on commit 09f6e67

Please sign in to comment.