Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback solver #161

Merged
merged 14 commits into from
Mar 30, 2024
56 changes: 56 additions & 0 deletions examples/ewf/molecules/65-callback-solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import pyscf
import pyscf.gto
import pyscf.scf
import pyscf.fci
import vayesta
import vayesta.ewf
from vayesta.misc.molecules import ring

# User defined FCI solver - takes pyscf mf as input and returns RDMs
def solver(mf):
h1e = mf.get_hcore()
h2e = mf._eri
norb = mf.mo_coeff.shape[-1]
nelec = mf.mol.nelec
energy, civec = pyscf.fci.direct_spin0.kernel(h1e, h2e, norb, nelec, conv_tol=1.e-14)
dm1, dm2 = pyscf.fci.direct_spin0.make_rdm12(civec, norb, nelec)
results = dict(dm1=dm1, dm2=dm2, converged=True)
return results

natom = 10
mol = pyscf.gto.Mole()
mol.atom = ring("H", natom, 1.5)
mol.basis = "sto-3g"
mol.output = "pyscf.out"
mol.verbose = 5
mol.symmetry = True
mol.build()

# Hartree-Fock
mf = pyscf.scf.RHF(mol)
mf.kernel()

# Vayesta options
use_sym = True
nfrag = 1
bath_opts = dict(bathtype="dmet")

# Run vayesta with user defined solver
emb = vayesta.ewf.EWF(mf, solver="CALLBACK", energy_functional='dm', bath_options=bath_opts, solver_options=dict(callback=solver))
# Set up fragments
with emb.iao_fragmentation() as f:
if use_sym:
# Add rotational symmetry
with f.rotational_symmetry(order=natom//nfrag, axis=[0, 0, 1]):
f.add_atomic_fragment(range(nfrag))
else:
# Add all atoms as separate fragments
f.add_all_atomic_fragments()
emb.kernel()

print("Hartree-Fock energy : %s"%mf.e_tot)
print("DMET energy : %s"%emb.get_dmet_energy(part_cumulant=False, approx_cumulant=False))
print("DMET energy (part-cumulant): %s"%emb.get_dmet_energy(part_cumulant=True, approx_cumulant=False))
print("DMET energy (approx-cumulant): %s"%emb.get_dmet_energy(part_cumulant=True, approx_cumulant=True))

3 changes: 1 addition & 2 deletions vayesta/core/qemb/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ def get_fragment_dmet_energy(
"""
assert mpi.rank == self.mpi_rank
if dm1 is None:
dm1 = self.results.dm1
dm1 = self.results.dm1.copy()
if dm1 is None:
raise RuntimeError("DM1 not found for %s" % self)
c_act = self.cluster.c_active
Expand All @@ -1069,7 +1069,6 @@ def get_fragment_dmet_energy(

if dm2 is None:
dm2 = self.results.wf.make_rdm2(with_dm1=not part_cumulant, approx_cumulant=approx_cumulant)

# Get effective core potential
if h1e_eff is None:
if part_cumulant:
Expand Down
3 changes: 3 additions & 0 deletions vayesta/core/qemb/qemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,11 @@ class Options(OptionsBase):
store_as_ccsd=None,
# Dump
dumpfile="clusters.h5",
# Callback
callback = None,
# MP2
compress_cderi=False,

)
# --- Other
symmetry_tol: float = 1e-6 # Tolerance (in Bohr) for atomic positions
Expand Down
4 changes: 4 additions & 0 deletions vayesta/core/types/wf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vayesta.core.types.wf.cisd import CISD_WaveFunction, RCISD_WaveFunction, UCISD_WaveFunction
from vayesta.core.types.wf.ccsd import CCSD_WaveFunction, RCCSD_WaveFunction, UCCSD_WaveFunction
from vayesta.core.types.wf.fci import FCI_WaveFunction, RFCI_WaveFunction, UFCI_WaveFunction
from vayesta.core.types.wf.rdm import RDM_WaveFunction, RRDM_WaveFunction, URDM_WaveFunction

# WIP:
from vayesta.core.types.wf.cisdtq import CISDTQ_WaveFunction, RCISDTQ_WaveFunction, UCISDTQ_WaveFunction
Expand Down Expand Up @@ -38,4 +39,7 @@
"CCSDTQ_WaveFunction",
"RCCSDTQ_WaveFunction",
"UCCSDTQ_WaveFunction",
"RDM_WaveFunction",
"RRDM_WaveFunction",
"URDM_WaveFunction",
]
110 changes: 110 additions & 0 deletions vayesta/core/types/wf/rdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import numpy as np

from vayesta.core import spinalg
from vayesta.core.util import dot, einsum, callif
from vayesta.core.types import wf as wf_types
from vayesta.core.types.wf.project import (
project_c1,
project_c2,
project_uc1,
project_uc2,
symmetrize_c2,
symmetrize_uc2,
transform_c1,
transform_c2,
transform_uc1,
transform_uc2,
)

def RDM_WaveFunction(mo, dm1, dm2, **kwargs):
if mo.nspin == 1:
cls = RRDM_WaveFunction
elif mo.nspin == 2:
cls = URDM_Wavefunction
return cls(mo, dm1, dm2, **kwargs)

class RRDM_WaveFunction(wf_types.WaveFunction):
"""
Dummy wavefunction type that stores the 1- and 2-RDMs.
Allows interoperability with user-defined callback solvers
which can only return the 1- and 2-RDMs.
"""
def __init__(self, mo, dm1, dm2, projector=None):
super().__init__(mo, projector=projector)
self.dm1 = dm1
self.dm2 = dm2

def make_rdm1(self, ao_basis=False, with_mf=True):
dm1 = self.dm1.copy()
if not with_mf:
dm1[np.diag_indices(self.nocc)] -= 2
if not ao_basis:
return dm1
return dot(self.mo.coeff, dm1, self.mo.coeff.T)

def make_rdm2(self, ao_basis=False, with_dm1=True, approx_cumulant=True):
dm1, dm2 = self.dm1.copy(), self.dm2.copy()
if not with_dm1:
if not approx_cumulant:
dm2 -= einsum("ij,kl->ijkl", dm1, dm1) - einsum("ij,kl->iklj", dm1, dm1) / 2
elif approx_cumulant in (1, True):
dm1[np.diag_indices(self.nocc)] -= 1
for i in range(self.nocc):
dm2[i, i, :, :] -= 2 * dm1
dm2[:, :, i, i] -= 2 * dm1
dm2[:, i, i, :] += dm1
dm2[i, :, :, i] += dm1
elif approx_cumulant == 2:
raise NotImplementedError
else:
raise ValueError
if not ao_basis:
return dm2
return einsum("ijkl,ai,bj,ck,dl->abcd", dm2, *(4 * [self.mo.coeff]))

def make_rdm2_non_cumulant(self, ao_basis=False):
dm1 = self.dm1.copy()
dm2 = einsum("ij,kl->ijkl", dm1, dm1) - einsum("ij,kl->iklj", dm1, dm1) / 2
if not ao_basis:
return dm2
return einsum("ijkl,ai,bj,ck,dl->abcd", dm2, *(4 * [self.mo.coeff]))

def copy(self):
dm1 = spinalg.copy(self.dm1)
dm2 = spinalg.copy(self.dm2)
proj = callif(spinalg.copy, self.projector)
return type(self)(self.mo.copy(), dm1, dm2, projector=proj)

def project(self, projector, inplace):
wf = self if inplace else self.copy()
wf.dm1 = project_c1(wf.dm1, projector)
wf.dm2 = project_c2(wf.dm2, projector)
wf.projector = projector
return wf

def pack(self, dtype=float):
"""Pack into a single array of data type `dtype`.

Useful for communication via MPI."""
mo = self.mo.pack(dtype=dtype)
data = (mo, dm1, dm2, self.projector)
pack = pack_arrays(*data, dtype=dtype)
return pack

@classmethod
def unpack(cls, packed):
"""Unpack from a single array of data type `dtype`.

Useful for communication via MPI."""
mo, dm1, dm2, projector = unpack_arrays(packed)
mo = SpatialOrbitals.unpack(mo)
return cls(mo, dm1, dm2, projector=projector)

def restore(self):
raise NotImplementedError()

def as_unrestricted(self):
raise NotImplementedError()

class URDM_WaveFunction(RRDM_WaveFunction):
pass
7 changes: 6 additions & 1 deletion vayesta/ewf/ewf.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def tailor_all_fragments(self):
def kernel(self):
"""Run EWF."""
t_start = timer()

print("EWF opts \n%s\n"%str(self.opts.solver_options))
# Automatic fragmentation
if len(self.fragments) == 0:
self.log.debug("No fragments found. Adding all atomic IAO fragments.")
Expand Down Expand Up @@ -180,8 +180,13 @@ def kernel(self):
self.log.error("Some fragments did not converge!")
self.converged = conv

if self.solver.lower() == "callback":
self.log.info("Total wall time: %s", time_string(timer() - t_start))
return
# --- Evaluate correlation energy and log information
self.e_corr = self.get_e_corr()


self.log.output("E(MF)= %s", energy_string(self.e_mf))
self.log.output("E(corr)= %s", energy_string(self.e_corr))
self.log.output("E(tot)= %s", energy_string(self.e_tot))
Expand Down
19 changes: 15 additions & 4 deletions vayesta/ewf/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from vayesta.core.util import deprecated, dot, einsum, energy_string, getattr_recursive, hstack, log_method, log_time
from vayesta.core.qemb import Fragment as BaseFragment
from vayesta.core.fragmentation import IAO_Fragmentation
from vayesta.core.types import RFCI_WaveFunction, RCCSDTQ_WaveFunction, UCCSDTQ_WaveFunction
from vayesta.core.types import RFCI_WaveFunction, RCCSDTQ_WaveFunction, UCCSDTQ_WaveFunction, RDM_WaveFunction, RRDM_WaveFunction, URDM_WaveFunction
from vayesta.core.bath import DMET_Bath
from vayesta.mpi import mpi

Expand Down Expand Up @@ -74,7 +74,8 @@ class Results(BaseFragment.Results):
ip_energy: np.ndarray = None
ea_energy: np.ndarray = None
moms: tuple = None

callback_results: dict = None

@property
def dm1(self):
"""Cluster 1DM"""
Expand All @@ -85,6 +86,7 @@ def dm2(self):
"""Cluster 2DM"""
return self.wf.make_rdm2()


def __init__(self, *args, **kwargs):
"""
Parameters
Expand Down Expand Up @@ -234,6 +236,7 @@ def kernel(self, solver=None, init_guess=None):
if not self.base.opts._debug_wf:
with log_time(self.log.info, ("Time for %s solver:" % solver) + " %s"):
cluster_solver.kernel()

# Special debug "solver"
else:
if self.base.opts._debug_wf == "random":
Expand All @@ -259,13 +262,18 @@ def kernel(self, solver=None, init_guess=None):
# Projection of CCSDTQ wave function is not implemented - convert to CCSD
elif isinstance(wf, (RCCSDTQ_WaveFunction, UCCSDTQ_WaveFunction)):
pwf = wf.as_ccsd()
proj = self.get_overlap("proj|cluster-occ")
if isinstance(wf, (RRDM_WaveFunction, URDM_WaveFunction)):
proj = self.get_overlap("cluster|frag")
proj = proj @ proj.T
else:
proj = self.get_overlap("proj|cluster-occ")
pwf = pwf.project(proj, inplace=False)

# Moments

moms = cluster_solver.hole_moments, cluster_solver.particle_moments

callback_results = cluster_solver.callback_results if solver.lower() == "callback" else None
# --- Add to results data class
self._results = results = self.Results(
fid=self.id,
Expand All @@ -275,12 +283,13 @@ def kernel(self, solver=None, init_guess=None):
pwf=pwf,
moms=moms,
e_corr_rpa=e_corr_rpa,
callback_results=callback_results,
)

self.hamil = cluster_solver.hamil

# --- Correlation energy contributions
if self.opts.calc_e_wf_corr:
if self.opts.calc_e_wf_corr and not isinstance(wf, (RRDM_WaveFunction, URDM_WaveFunction)):
ci = wf.as_cisd(c0=1.0)

ci = ci.project(proj)
Expand Down Expand Up @@ -324,6 +333,8 @@ def get_solver_options(self, solver):
solver_opts["tcc_fci_opts"] = self.opts.tcc_fci_opts
elif solver.upper() == "DUMP":
solver_opts["filename"] = self.opts.solver_options["dumpfile"]
if solver.upper() == 'CALLBACK':
solver_opts["callback"] = self.opts.solver_options["callback"]
solver_opts["external_corrections"] = self.flags.external_corrections
solver_opts["test_extcorr"] = self.flags.test_extcorr
return solver_opts
Expand Down
3 changes: 3 additions & 0 deletions vayesta/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from vayesta.solver.cisd import RCISD_Solver, UCISD_Solver
from vayesta.solver.coupled_ccsd import coupledRCCSD_Solver
from vayesta.solver.dump import DumpSolver
from vayesta.solver.callback import CallbackSolver
from vayesta.solver.ebfci import EB_EBFCI_Solver, EB_UEBFCI_Solver
from vayesta.solver.ext_ccsd import extRCCSD_Solver, extUCCSD_Solver
from vayesta.solver.fci import FCI_Solver, UFCI_Solver
Expand Down Expand Up @@ -129,4 +130,6 @@ def get_right_CC(*args, **kwargs):
return RCISD_Solver
if solver == "DUMP":
return DumpSolver
if solver == 'CALLBACK':
return CallbackSolver
raise ValueError("Unknown solver: %s" % solver)
50 changes: 50 additions & 0 deletions vayesta/solver/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import dataclasses
from typing import Callable
import numpy as np

from vayesta.core.types import CISD_WaveFunction, CCSD_WaveFunction, FCI_WaveFunction, RDM_WaveFunction
from vayesta.solver.solver import ClusterSolver

class CallbackSolver(ClusterSolver):
@dataclasses.dataclass
class Options(ClusterSolver.Options):
# Need to specify a type for this to work
callback: int = None

def kernel(self, *args, **kwargs):
mf_clus, frozen = self.hamil.to_pyscf_mf(allow_dummy_orbs=True, allow_df=True)
results = self.opts.callback(mf_clus)

# Build appropriate wavefunction object
if 'civec' in results:
self.log.info("FCI WaveFunction found in callback results.")
wf = FCI_WaveFunction(self.hamil.mo, results['civec'])
elif 't1' in results and 't2' in results:
self.log.info("CCSD WaveFunction found in callback results.")
t1, t2 = results['t1'], results['t2']
if 'l1' in results and 'l2' in results:
l1, l2 = results['l1'], results['l2']
else:
l1, l2 = None, None
wf = CCSD_WaveFunction(self.hamil.mo, t1, t2, l1=l1, l2=l2)
elif 'c0' in results and 'c1' in results and 'c2' in results:
self.log.info("CISD WaveFunction found in callback results.")
c0, c1, c2 = results['c0'], results['c1'], results['c2']
wf = CISD_WaveFunction(self.hamil.mo, c0, c1, c2)
elif 'dm1' in results and 'dm2' in results:
self.log.info("RDM WaveFunction found in callback results.")
dm1, dm2 = results['dm1'], results['dm2']
wf = RDM_WaveFunction(self.hamil.mo, dm1, dm2)
else:
self.log.warn("No wavefunction results returned by callback!")

if 'hole_moments' in results:
self.log.info("Hole moments found in callback results.")
self.hole_moments = results['hole_moments']
if 'particle_moments' in results:
self.log.info("Particle moments found in callback results.")
self.particle_moments = results['particle_moments']

results['wf'] = wf
self.wf = wf
self.callback_results = results
Loading
Loading