diff --git a/my_pyscf/lassi/op_o1/hci.py b/my_pyscf/lassi/op_o1/hci.py index b04558e8..34809452 100644 --- a/my_pyscf/lassi/op_o1/hci.py +++ b/my_pyscf/lassi/op_o1/hci.py @@ -34,6 +34,7 @@ def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=Non # Handling for 1s1c: need to do both a'.sm.b and b'.sp.a explicitly all_interactions_full_square = True interaction_has_spin = ('_1c_', '_1c1d_', '_1s1c_', '_2c_') + ltri_ambiguous = False def _init_vecs (self): hci_fr_pabq = [] diff --git a/my_pyscf/lassi/op_o1/hsi.py b/my_pyscf/lassi/op_o1/hsi.py index 859e78d1..522a0627 100644 --- a/my_pyscf/lassi/op_o1/hsi.py +++ b/my_pyscf/lassi/op_o1/hsi.py @@ -29,8 +29,6 @@ def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=Non HamS2Ovlp.__init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=mask_bra_space, mask_ket_space=mask_ket_space, log=log, max_memory=max_memory, dtype=dtype) - self.urootstr = np.asarray ([[i.unique_root[r] for i in self.ints] - for r in range (self.nroots)]).T self.x = self.si = np.zeros (self.nstates, self.dtype) self.ox = np.zeros (self.nstates, self.dtype) self.ox1 = np.zeros (self.nstates, self.dtype) @@ -202,12 +200,12 @@ def _opuniq_x_group_(self, inv, group): t0, w0 = logger.process_clock (), logger.perf_counter () kets = np.unique (ovlplink[:,0]) ketvecs = {ket: self.get_xvec (ket, *inv) for ket in set (kets)} - ovecs = {tuple(row):0 for row in ovlplink} + ovecs = {tuple(row[1:]):0 for row in ovlplink} for row in ovlplink: xvec = ketvecs[row[0]] midstr = row[1:] ketstr = self.urootstr[:,row[0]] - ovecs[tuple(row)] += self.ox_ovlp_part (midstr, ketstr, xvec, inv) + ovecs[tuple(midstr)] += self.ox_ovlp_part (midstr, ketstr, xvec, inv) t1, w1 = logger.process_clock (), logger.perf_counter () self.dt_sX += (t1-t0) self.dw_sX += (w1-w0) @@ -228,22 +226,19 @@ def _opuniq_x_group_(self, inv, group): self.dt_pX += (t3-t2) self.dw_pX += (w3-w2) - def _opuniq_x_(self, op, bra, ket, ovecs, *inv): + def _opuniq_x_(self, op, obra, oket, ovecs, *inv): '''All operations which are unique in that a given set of fragment bra statelets are coupled to a given set of fragment ket statelets''' - key = tuple ((bra, ket)) + inv + key = tuple ((obra, oket)) + inv inv = list (set (inv)) brakets, bras, braHs = self.get_nonuniq_exc_square (key) - bravecs = {bra: 0.0 for bra in bras+braHs} - for bra, ket in brakets: - bravecs[bra] += ovecs[tuple ((ket,)) + self.ox_ovlp_urootstr (bra, ket, inv)] for bra in bras: - vec = bravecs[bra] + vec = ovecs[self.ox_ovlp_urootstr (bra, oket, inv)] self.put_ox1_(lib.dot (op, vec.T).ravel (), bra, *inv) if len (braHs): op = op.conj ().T for bra in braHs: - vec = bravecs[bra] + vec = ovecs[self.ox_ovlp_urootstr (bra, obra, inv)] self.put_ox1_(lib.dot (op, vec.T).ravel (), bra, *inv) return diff --git a/my_pyscf/lassi/op_o1/stdm.py b/my_pyscf/lassi/op_o1/stdm.py index eda3b598..4c821932 100644 --- a/my_pyscf/lassi/op_o1/stdm.py +++ b/my_pyscf/lassi/op_o1/stdm.py @@ -1,7 +1,7 @@ import numpy as np from pyscf import lib from pyscf.lib import logger -from itertools import product +from itertools import product, combinations from mrh.my_pyscf.lassi.citools import get_rootaddr_fragaddr, umat_dot_1frag_ from mrh.my_pyscf.lassi.op_o1 import frag from mrh.my_pyscf.lassi.op_o1.utilities import * @@ -107,6 +107,9 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_ for nelec_sf in self.nelec_rf] self.nelec_rf = self.nelec_rf.sum (1) + self.urootstr = np.asarray ([[i.unique_root[r] for i in self.ints] + for r in range (self.nroots)]).T + exc = self.make_exc_tables (hopping_index) self.nonuniq_exc = {} self.exc_null = self.mask_exc_table_(exc['null'], 'null', mask_bra_space, mask_ket_space) @@ -136,6 +139,14 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_ else: raise NotImplementedError (self.dtype) + def interaction_fprint (self, bra, ket, frags, ltri=False): + frags = np.sort (frags) + brastr = self.urootstr[frags,bra] + ketstr = self.urootstr[frags,ket] + if ltri: brastr, ketstr = sorted ([list(brastr),list(ketstr)]) + fprint = np.stack ([frags, brastr, ketstr], axis=0) + return fprint + def init_profiling (self): self.dt_1d, self.dw_1d = 0.0, 0.0 self.dt_2d, self.dw_2d = 0.0, 0.0 @@ -326,6 +337,7 @@ def make_exc_tables (self, hopping_index): all_interactions_full_square = False interaction_has_spin = ('_1c_', '_1c1d_', '_2c_') + ltri_ambiguous = True def mask_exc_table_(self, exc, lbl, mask_bra_space=None, mask_ket_space=None): # Part 1: restrict to the caller-specified rectangle @@ -339,28 +351,31 @@ def mask_exc_table_(self, exc, lbl, mask_bra_space=None, mask_ket_space=None): if lbl=='null': return exc ulblu = '_' + lbl + '_' excp = exc[:,:-1] if ulblu in self.interaction_has_spin else exc + fprintLT = [] fprint = [] for row in excp: - frow = [] bra, ket = row[:2] frags = row[2:] - for frag in frags: - intf = self.ints[frag] - frow.extend ([frag, intf.unique_root[bra], intf.unique_root[ket]]) - fprint.append (frow) + fpLT = self.interaction_fprint (bra, ket, frags, ltri=self.ltri_ambiguous) + fprintLT.append (fpLT.ravel ()) + fp = self.interaction_fprint (bra, ket, frags, ltri=False) + fprint.append (fp.ravel ()) + fprintLT = np.asarray (fprintLT) fprint = np.asarray (fprint) nexc = len (exc) - _, idx, inv = np.unique (fprint, axis=0, return_index=True, return_inverse=True) + fprintLT, idx, inv = np.unique (fprintLT, axis=0, return_index=True, return_inverse=True) # for some reason this squeeze is necessary for some versions of numpy; however... eqmap = np.squeeze (idx[inv]) - for uniq_idx in idx: + for fpLT, uniq_idx in zip (fprintLT, idx): row_uniq = excp[uniq_idx] # ...numpy.where (0==0) triggers a DeprecationWarning, so I have to atleast_1d it uniq_idxs = np.where (np.atleast_1d (eqmap==uniq_idx))[0] braket_images = exc[np.ix_(uniq_idxs,[0,1])] + iT = np.any (fprint[uniq_idx][None,:]!=fprint[uniq_idxs], axis=1) + braket_images[iT,:] = braket_images[iT,::-1] self.nonuniq_exc[tuple(row_uniq)] = braket_images exc = exc[idx] - nuniq = len (idx) + nuniq = len (exc) self.log.debug ('%d/%d unique interactions of %s type', nuniq, nexc, lbl) return exc diff --git a/pyscf-forge_version.txt b/pyscf-forge_version.txt index 12b66ad8..3acb68d1 100644 --- a/pyscf-forge_version.txt +++ b/pyscf-forge_version.txt @@ -1,2 +1,2 @@ -git+https://github.com/pyscf/pyscf-forge.git@b7454c8a8c8d2ad13a03e93dc70b538e38d1bf8b +git+https://github.com/pyscf/pyscf-forge.git@226601b79600961a89b034f2c541861901200b71 diff --git a/pyscf_version.txt b/pyscf_version.txt index ebb01bec..b0fcbbad 100644 --- a/pyscf_version.txt +++ b/pyscf_version.txt @@ -1 +1 @@ -git+https://github.com/pyscf/pyscf.git@afdd09fcf08c32a6de05f4ec827e8d1602531707 +git+https://github.com/pyscf/pyscf.git@bee0ce288a655105e27fcb0293b203939b7aecc9 diff --git a/tests/lassi/test_4frag.py b/tests/lassi/test_4frag.py index 523d9172..d6a28893 100644 --- a/tests/lassi/test_4frag.py +++ b/tests/lassi/test_4frag.py @@ -175,11 +175,12 @@ def test_lassis_1111 (self): lsi.prepare_states_() h0, h1, h2 = ham_2q (las1, las1.mo_coeff) case_contract_op_si (self, las1, h1, h2, lsi.ci, lsi.get_nelec_frs ()) - lsi.kernel () - self.assertTrue (lsi.converged) - self.assertAlmostEqual (lsi.e_roots[0], -1.867291372401379, 6) - case_lassis_fbf_2_model_state (self, lsi) - case_lassis_fbfdm (self, lsi) + else: + lsi.kernel () + self.assertTrue (lsi.converged) + self.assertAlmostEqual (lsi.e_roots[0], -1.867291372401379, 6) + case_lassis_fbf_2_model_state (self, lsi) + case_lassis_fbfdm (self, lsi) def test_lassis_slow (self): las0 = las.get_single_state_las (state=0) @@ -192,11 +193,12 @@ def test_lassis_slow (self): lsi.prepare_states_() h0, h1, h2 = ham_2q (las0, las0.mo_coeff) case_contract_op_si (self, las, h1, h2, lsi.ci, lsi.get_nelec_frs ()) - lsi.kernel () - self.assertTrue (lsi.converged) - self.assertAlmostEqual (lsi.e_roots[0], -304.5372586630968, 3) - case_lassis_fbf_2_model_state (self, lsi) - #case_lassis_fbfdm (self, lsi) + else: + lsi.kernel () + self.assertTrue (lsi.converged) + self.assertAlmostEqual (lsi.e_roots[0], -304.5372586630968, 3) + case_lassis_fbf_2_model_state (self, lsi) + #case_lassis_fbfdm (self, lsi) def test_fdm1 (self): make_fdm1 = get_fdm1_maker (las, las.ci, nelec_frs, si)