Skip to content

Commit

Permalink
Merge branch 'fix_vayesta_interface' into scipy_version
Browse files Browse the repository at this point in the history
  • Loading branch information
basilib committed Dec 20, 2024
2 parents 8c795e5 + a502e33 commit ed377b3
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ dyson = [
"dyson @ git+https://github.com/BoothGroup/dyson@master",
]
ebcc = [
"ebcc<=1.5.0",
"ebcc>=1.4.0",
]
pygnme = [
"pygnme @ git+https://github.com/BoothGroup/pygnme@master"
Expand All @@ -76,7 +76,7 @@ dev = [
"cvxpy>=1.1",
"mpi4py>=3.0.0",
"dyson @ git+https://github.com/BoothGroup/dyson@master",
"ebcc<=1.5.0",
"ebcc>=1.4.0",
"black>=22.6.0",
"pytest",
"pytest-cov",
Expand Down
129 changes: 129 additions & 0 deletions vayesta/solver/ebcc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,135 @@
import dataclasses

import numpy as np
import pyscf.ao2mo

from vayesta.core.types import WaveFunction, CCSD_WaveFunction, EBCC_WaveFunction
from vayesta.core.util import dot, einsum
from vayesta.solver.solver import ClusterSolver, UClusterSolver
import ebcc


class RERIs(ebcc.ham.base.BaseERIs, ebcc.ham.base.BaseRHamiltonian):
"""Restricted ERIs container class."""

def __init__(self, mf, space, mo_coeff=None, array=None):
"""Initialise the Hamiltonian.
Args:
mf: Mean-field object.
space: Space object for each index.
mo_coeff: Molecular orbital coefficients for each index.
"""
ebcc.util.Namespace.__init__(self)

# Parameters:
self.__dict__["mf"] = mf
self.__dict__["space"] = space
if isinstance(mo_coeff, tuple) and len(mo_coeff) == 4:
self.__dict__["mo_coeff"] = mo_coeff
else:
self.__dict__["mo_coeff"] = (mo_coeff,) * 4
self.__dict__["array"] = array

def __getitem__(self, key: str) -> np.typing.NDArray[np.floating]:
"""Just-in-time getter.
Args:
key: Key to get.
Returns:
ERIs for the given spaces.
"""
if self.array is None:
if key not in self._members.keys():
coeffs = [
self.mo_coeff[i][:, self.space[i].slice(k)].astype(np.float64)
for i, k in enumerate(key)
]
if getattr(self.mf, "_eri", None) is not None:
block = pyscf.ao2mo.incore.general(self.mf._eri, coeffs, compact=False)
else:
block = pyscf.ao2mo.kernel(self.mf.mol, coeffs, compact=False)
block = np.reshape(block, [c.shape[-1] for c in coeffs])
self._members[key] = block
return self._members[key]
else:
ijkl = tuple(self.space[i].slice(k) for i, k in enumerate(key))
return self.array[ijkl] # type: ignore


class UERIs(ebcc.ham.base.BaseERIs, ebcc.ham.base.BaseUHamiltonian):
"""Unrestricted ERIs container class.
Restores deprecated functionality allowing spin-dependent AO ERIs.
"""

def __init__(self, mf, space, mo_coeff=None, array=None):
"""Initialise the Hamiltonian.
Args:
mf: Mean-field object.
space: Space object for each index.
mo_coeff: Molecular orbital coefficients for each index.
"""
ebcc.util.Namespace.__init__(self)

# Parameters:
self.__dict__["mf"] = mf
self.__dict__["space"] = space
if isinstance(mo_coeff, tuple) and len(mo_coeff) == 4:
self.__dict__["mo_coeff"] = mo_coeff
else:
self.__dict__["mo_coeff"] = (mo_coeff,) * 4
self.__dict__["array"] = array

def __getitem__(self, key: str) -> RERIs:
"""Just-in-time getter.
Args:
key: Key to get.
Returns:
ERIs for the given spins.
"""
if key not in ("aaaa", "aabb", "bbaa", "bbbb"):
raise KeyError(f"Invalid key: {key}")
if key not in self._members:
i = "ab".index(key[0])
j = "ab".index(key[2])
ij = i * (i + 1) // 2 + j

if self.array is not None:
array = self.array[ij]
if key == "bbaa":
array = np.transpose(array, (2, 3, 0, 1))
elif isinstance(self.mf._eri, tuple):
# Support spin-dependent integrals in the mean-field
coeffs = [
self.mo_coeff[y][x].astype(np.float64)
for y, x in enumerate(sorted((i, i, j, j)))
]
array = pyscf.ao2mo.incore.general(self.mf._eri[ij], coeffs, compact=False)
if key == "bbaa":
array = np.transpose(array, (2, 3, 0, 1))
array = array
else:
array = None

self._members[key] = RERIs(
self.mf,
(self.space[0][i], self.space[1][i], self.space[2][j], self.space[3][j]),
mo_coeff=(
self.mo_coeff[0][i],
self.mo_coeff[1][i],
self.mo_coeff[2][j],
self.mo_coeff[3][j],
),
array=array,
)
return self._members[key]


class REBCC_Solver(ClusterSolver):
@dataclasses.dataclass
class Options(ClusterSolver.Options):
Expand All @@ -30,6 +152,7 @@ def kernel(self):
mycc = ebcc.EBCC(
mf_clus, log=self.log, ansatz=self.opts.ansatz, space=space, shift=False, **self.get_nonnull_solver_opts()
)
mycc.ERIs = self.get_eris_class()
mycc.kernel()
self.converged = mycc.converged
if self.opts.solve_lambda:
Expand Down Expand Up @@ -65,6 +188,9 @@ def add_nonull_opt(d, key, newkey):
add_nonull_opt(opts, key, newkey)
return opts

def get_eris_class(self):
return RERIs

def construct_wavefunction(self, mycc, mo, mbos=None):
if self.opts.store_as_ccsd:
# Can use existing functionality
Expand Down Expand Up @@ -130,6 +256,9 @@ def _get_space(c, occ, co_cas, cv_cas, fr):
cb, spaceb = _get_space(mo_coeff[1], mo_occ[1], self.opts.c_cas_occ[1], self.opts.c_cas_vir[1], frozen[1])
return (ca, cb), (spacea, spaceb)

def get_eris_class(self):
return UERIs

# This should automatically work other than ensuring spin components are in a tuple.
def construct_wavefunction(self, mycc, mo, mbos=None):
if self.opts.store_as_ccsd:
Expand Down

0 comments on commit ed377b3

Please sign in to comment.