From a502e33ba3c600744ceb8ab656645ee4c0d31ffa Mon Sep 17 00:00:00 2001 From: Oliver Backhouse Date: Wed, 18 Dec 2024 15:28:47 +0000 Subject: [PATCH] Fixing ebcc interface --- pyproject.toml | 4 +- vayesta/solver/ebcc.py | 129 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 98d88cd2..e63f3683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dyson = [ "dyson @ git+https://github.com/BoothGroup/dyson@master", ] ebcc = [ - "ebcc<=1.5.0", + "ebcc>=1.4.0", ] pygnme = [ "pygnme @ git+https://github.com/BoothGroup/pygnme@master" @@ -75,7 +75,7 @@ dev = [ "cvxpy>=1.1", "mpi4py>=3.0.0", "dyson @ git+https://github.com/BoothGroup/dyson@master", - "ebcc<=1.5.0", + "ebcc>=1.4.0", "black>=22.6.0", "pytest", "pytest-cov", diff --git a/vayesta/solver/ebcc.py b/vayesta/solver/ebcc.py index 6b04b4d1..718cb723 100644 --- a/vayesta/solver/ebcc.py +++ b/vayesta/solver/ebcc.py @@ -1,6 +1,7 @@ import dataclasses import numpy as np +import pyscf.ao2mo from vayesta.core.types import WaveFunction, CCSD_WaveFunction, EBCC_WaveFunction from vayesta.core.util import dot, einsum @@ -8,6 +9,127 @@ import ebcc +class RERIs(ebcc.ham.base.BaseERIs, ebcc.ham.base.BaseRHamiltonian): + """Restricted ERIs container class.""" + + def __init__(self, mf, space, mo_coeff=None, array=None): + """Initialise the Hamiltonian. + + Args: + mf: Mean-field object. + space: Space object for each index. + mo_coeff: Molecular orbital coefficients for each index. + """ + ebcc.util.Namespace.__init__(self) + + # Parameters: + self.__dict__["mf"] = mf + self.__dict__["space"] = space + if isinstance(mo_coeff, tuple) and len(mo_coeff) == 4: + self.__dict__["mo_coeff"] = mo_coeff + else: + self.__dict__["mo_coeff"] = (mo_coeff,) * 4 + self.__dict__["array"] = array + + def __getitem__(self, key: str) -> np.typing.NDArray[np.floating]: + """Just-in-time getter. + + Args: + key: Key to get. + + Returns: + ERIs for the given spaces. + """ + if self.array is None: + if key not in self._members.keys(): + coeffs = [ + self.mo_coeff[i][:, self.space[i].slice(k)].astype(np.float64) + for i, k in enumerate(key) + ] + if getattr(self.mf, "_eri", None) is not None: + block = pyscf.ao2mo.incore.general(self.mf._eri, coeffs, compact=False) + else: + block = pyscf.ao2mo.kernel(self.mf.mol, coeffs, compact=False) + block = np.reshape(block, [c.shape[-1] for c in coeffs]) + self._members[key] = block + return self._members[key] + else: + ijkl = tuple(self.space[i].slice(k) for i, k in enumerate(key)) + return self.array[ijkl] # type: ignore + + +class UERIs(ebcc.ham.base.BaseERIs, ebcc.ham.base.BaseUHamiltonian): + """Unrestricted ERIs container class. + + Restores deprecated functionality allowing spin-dependent AO ERIs. + """ + + def __init__(self, mf, space, mo_coeff=None, array=None): + """Initialise the Hamiltonian. + + Args: + mf: Mean-field object. + space: Space object for each index. + mo_coeff: Molecular orbital coefficients for each index. + """ + ebcc.util.Namespace.__init__(self) + + # Parameters: + self.__dict__["mf"] = mf + self.__dict__["space"] = space + if isinstance(mo_coeff, tuple) and len(mo_coeff) == 4: + self.__dict__["mo_coeff"] = mo_coeff + else: + self.__dict__["mo_coeff"] = (mo_coeff,) * 4 + self.__dict__["array"] = array + + def __getitem__(self, key: str) -> RERIs: + """Just-in-time getter. + + Args: + key: Key to get. + + Returns: + ERIs for the given spins. + """ + if key not in ("aaaa", "aabb", "bbaa", "bbbb"): + raise KeyError(f"Invalid key: {key}") + if key not in self._members: + i = "ab".index(key[0]) + j = "ab".index(key[2]) + ij = i * (i + 1) // 2 + j + + if self.array is not None: + array = self.array[ij] + if key == "bbaa": + array = np.transpose(array, (2, 3, 0, 1)) + elif isinstance(self.mf._eri, tuple): + # Support spin-dependent integrals in the mean-field + coeffs = [ + self.mo_coeff[y][x].astype(np.float64) + for y, x in enumerate(sorted((i, i, j, j))) + ] + array = pyscf.ao2mo.incore.general(self.mf._eri[ij], coeffs, compact=False) + if key == "bbaa": + array = np.transpose(array, (2, 3, 0, 1)) + array = array + else: + array = None + + self._members[key] = RERIs( + self.mf, + (self.space[0][i], self.space[1][i], self.space[2][j], self.space[3][j]), + mo_coeff=( + self.mo_coeff[0][i], + self.mo_coeff[1][i], + self.mo_coeff[2][j], + self.mo_coeff[3][j], + ), + array=array, + ) + return self._members[key] + + class REBCC_Solver(ClusterSolver): @dataclasses.dataclass class Options(ClusterSolver.Options): @@ -30,6 +152,7 @@ def kernel(self): mycc = ebcc.EBCC( mf_clus, log=self.log, ansatz=self.opts.ansatz, space=space, shift=False, **self.get_nonnull_solver_opts() ) + mycc.ERIs = self.get_eris_class() mycc.kernel() self.converged = mycc.converged if self.opts.solve_lambda: @@ -65,6 +188,9 @@ def add_nonull_opt(d, key, newkey): add_nonull_opt(opts, key, newkey) return opts + def get_eris_class(self): + return RERIs + def construct_wavefunction(self, mycc, mo, mbos=None): if self.opts.store_as_ccsd: # Can use existing functionality @@ -130,6 +256,9 @@ def _get_space(c, occ, co_cas, cv_cas, fr): cb, spaceb = _get_space(mo_coeff[1], mo_occ[1], self.opts.c_cas_occ[1], self.opts.c_cas_vir[1], frozen[1]) return (ca, cb), (spacea, spaceb) + def get_eris_class(self): + return UERIs + # This should automatically work other than ensuring spin components are in a tuple. def construct_wavefunction(self, mycc, mo, mbos=None): if self.opts.store_as_ccsd: