From e60de4104b7f85c6e6ea18fbdf1b1c9ac58aeebb Mon Sep 17 00:00:00 2001 From: Max Nusspickel Date: Sat, 23 Sep 2023 11:19:24 +0100 Subject: [PATCH 1/5] Refactors solver/__init__.py --- vayesta/solver/__init__.py | 184 +++++++++++++++++-------------------- 1 file changed, 85 insertions(+), 99 deletions(-) diff --git a/vayesta/solver/__init__.py b/vayesta/solver/__init__.py index cfd7e0ad4..3b53fc83d 100644 --- a/vayesta/solver/__init__.py +++ b/vayesta/solver/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations +from typing import * + from vayesta.solver.ccsd import RCCSD_Solver, UCCSD_Solver from vayesta.solver.cisd import RCISD_Solver, UCISD_Solver from vayesta.solver.coupled_ccsd import coupledRCCSD_Solver @@ -11,122 +14,105 @@ try: from vayesta.solver.ebcc import REBCC_Solver, UEBCC_Solver, EB_REBCC_Solver, EB_UEBCC_Solver + _has_ebcc = True except ImportError: + REBCC_Solver = UEBCC_Solver = EB_REBCC_Solver = EB_UEBCC_Solver = None _has_ebcc = False -else: - _has_ebcc = True + +if TYPE_CHECKING: + from logging import Logger def get_solver_class(ham, solver): assert is_ham(ham) uhf = is_uhf_ham(ham) eb = is_eb_ham(ham) - return _get_solver_class(uhf, eb, solver, ham.log) + return _get_solver_class(solver, uhf, eb, ham.log) -def check_solver_config(is_uhf, is_eb, solver, log): - _get_solver_class(is_uhf, is_eb, solver, log) +def check_solver_config(solver, is_uhf, is_eb, log): + _get_solver_class(solver, is_uhf, is_eb, log) -def _get_solver_class(is_uhf, is_eb, solver, log): +def _get_solver_class(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type: try: - solver_cls = _get_solver_class_internal(is_uhf, is_eb, solver, log) + solver_cls = _get_solver_class_internal(solver, is_uhf, is_eb, log) return solver_cls except ValueError as e: spinmessage = "unrestricted" if is_uhf else "restricted" - bosmessage = "coupled electron-boson" if is_eb else "purely electronic" - - fullmessage = f"Error; solver {solver} not available for {spinmessage} {bosmessage} systems" + ebmessage = " with electron-boson coupling" if is_eb else "" + fullmessage = f"solver '{solver}' not available for {spinmessage} systems{ebmessage}" log.critical(fullmessage) raise ValueError(fullmessage) -def _get_solver_class_internal(is_uhf, is_eb, solver, log): - # First check if we have a CC approach as implemented in pyscf. - if solver == "CCSD" and not is_eb: - # Use pyscf solvers. - if is_uhf: - return UCCSD_Solver - else: - return RCCSD_Solver - if solver == "TCCSD": - if is_uhf or is_eb: - raise ValueError("TCCSD is not implemented for unrestricted or electron-boson calculations!") - return TRCCSD_Solver - if solver == "extCCSD": - if is_eb: - raise ValueError("extCCSD is not implemented for electron-boson calculations!") - if is_uhf: - return extUCCSD_Solver - return extRCCSD_Solver - if solver == "coupledCCSD": - if is_eb: - raise ValueError("coupledCCSD is not implemented for electron-boson calculations!") - if is_uhf: - raise ValueError("coupledCCSD is not implemented for unrestricted calculations!") - return coupledRCCSD_Solver - - # Now consider general CC ansatzes; these are solved via EBCC. - # Note that we support all capitalisations of `ebcc`, but need `CC` to be capitalised when also using this to - # specify an ansatz. - if "CC" in solver.upper(): - if not _has_ebcc: - raise ImportError(f"{solver} solver is only accessible via ebcc. Please install ebcc.") - if is_uhf: - if is_eb: - solverclass = EB_UEBCC_Solver - else: - solverclass = UEBCC_Solver - else: - if is_eb: - solverclass = EB_REBCC_Solver - else: - solverclass = REBCC_Solver - if solver.upper() == "EBCC": - # Default to `opts.ansatz`. - return solverclass - if solver[:2].upper() == "EB": - solver = solver[2:] - if solver == "CCSD" and is_eb: - log.warning("CCSD solver requested for coupled electron-boson system; defaulting to CCSD-SD-1-1.") - solver = "CCSD-SD-1-1" - - # This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case. - def get_right_CC(*args, **kwargs): - setansatz = kwargs.get("ansatz", None) - if setansatz is not None: - if setansatz != solver: - raise ValueError( - "Desired CC ansatz specified differently in solver and solver_options.ansatz." - "Please use only specify via one approach, or ensure they agree." - ) - kwargs["ansatz"] = solver - return solverclass(*args, **kwargs) - - return get_right_CC - if solver == "FCI": - if is_uhf: - if is_eb: - return EB_UEBFCI_Solver - else: - return UFCI_Solver - else: - if is_eb: - return EB_EBFCI_Solver - else: - return FCI_Solver - if is_eb: - raise ValueError("%s solver is not implemented for coupled electron-boson systems!", solver) - if solver == "MP2": - if is_uhf: - return UMP2_Solver - else: - return RMP2_Solver - if solver == "CISD": - if is_uhf: - return UCISD_Solver - else: - return RCISD_Solver - if solver == "DUMP": - return DumpSolver - raise ValueError("Unknown solver: %s" % solver) +# (solver_string, is_uhf, is_eb) -> SolverClass +_solver_dict: Dict[Tuple[str, bool, bool], Type] = { + ('MP2', False, False): RMP2_Solver, + ('MP2', True, False): UMP2_Solver, + ('CISD', False, False): RCISD_Solver, + ('CISD', True, False): UCISD_Solver, + ('CCSD', False, False): RCCSD_Solver, + ('CCSD', True, False): UCCSD_Solver, + ('TCCSD', False, False): TRCCSD_Solver, + ('TCCSD', True, False): NotImplemented, + ('extCCSD', False, False): extRCCSD_Solver, + ('extCCSD', True, False): extUCCSD_Solver, + ('coupledCCSD', False, False): coupledRCCSD_Solver, + ('coupledCCSD', True, False): NotImplemented, + ('FCI', False, False): FCI_Solver, + ('FCI', True, False): UFCI_Solver, + ('FCI', False, True): EB_EBFCI_Solver, + ('FCI', True, True): EB_UEBFCI_Solver, + ('Dump', False, False): DumpSolver, + ('Dump', True, False): DumpSolver, +} + + +# (is_uhf, is_eb) -> SolverClass +_ebcc_solver_dict: Dict[Tuple[bool, bool], Type] = { + (False, False): REBCC_Solver, + (True, False): UEBCC_Solver, + (False, True): EB_REBCC_Solver, + (True, True): EB_UEBCC_Solver, +} + + +def _get_solver_class_internal(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type | Callable: + solver_cls = _solver_dict.get((solver, is_uhf, is_eb), None) + if solver_cls is NotImplemented: + spinsym = 'unrestricted' if is_uhf else 'restricted' + raise NotImplementedError(f"solver '{solver}' for {spinsym} spin-symmetry is not implemented") + if solver_cls is not None: + return solver_cls + if 'CC' not in solver: + raise ValueError(f"unknown solver '{solver}'") + # Try EBCC next + return _get_solver_class_ebcc(solver, is_uhf, is_eb, log) + + +def _get_solver_class_ebcc(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Type | Callable: + if not _has_ebcc: + raise ImportError(f"{solver} solver is only accessible via ebcc. Please install ebcc.") + solver_cls = _ebcc_solver_dict[is_uhf, is_eb] + if solver == "EBCC": + # Default to `opts.ansatz`. + return solver_cls + if solver[:2] == "EB": + solver = solver[2:] + if solver == "CCSD" and is_eb: + log.warning("CCSD solver requested for coupled electron-boson system; defaulting to CCSD-SD-1-1.") + solver = "CCSD-SD-1-1" + + # This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case. + def get_right_cc(*args, **kwargs): + setansatz = kwargs.get("ansatz", None) + if setansatz != solver: + raise ValueError( + "Desired CC ansatz specified differently in solver and solver_options.ansatz." + "Please use only specify via one approach, or ensure they agree." + ) + kwargs["ansatz"] = solver + return solver_cls(*args, **kwargs) + + return get_right_cc From c7508f4f896bc95274a8875437a089a610230d13 Mon Sep 17 00:00:00 2001 From: Max Nusspickel Date: Sat, 23 Sep 2023 12:23:36 +0100 Subject: [PATCH 2/5] Fixes check_solver_config argument order --- vayesta/core/qemb/fragment.py | 2 +- vayesta/core/qemb/qemb.py | 2 +- vayesta/edmet/edmet.py | 2 +- vayesta/edmet/fragment.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vayesta/core/qemb/fragment.py b/vayesta/core/qemb/fragment.py index cfd2dd11b..b1cb1830f 100644 --- a/vayesta/core/qemb/fragment.py +++ b/vayesta/core/qemb/fragment.py @@ -1161,7 +1161,7 @@ def check_solver(self, solver): is_eb = "crpa_full" in self.opts.screening else: is_eb = False - check_solver_config(is_uhf, is_eb, solver, self.log) + check_solver_config(solver, is_uhf, is_eb, self.log) def get_solver(self, solver=None): if solver is None: diff --git a/vayesta/core/qemb/qemb.py b/vayesta/core/qemb/qemb.py index 1b99529f2..8652f3829 100644 --- a/vayesta/core/qemb/qemb.py +++ b/vayesta/core/qemb/qemb.py @@ -1739,4 +1739,4 @@ def check_solver(self, solver): is_eb = "crpa_full" in self.opts.screening else: is_eb = False - check_solver_config(is_uhf, is_eb, solver, self.log) + check_solver_config(solver, is_uhf, is_eb, self.log) diff --git a/vayesta/edmet/edmet.py b/vayesta/edmet/edmet.py index ce5e1b331..f17501146 100644 --- a/vayesta/edmet/edmet.py +++ b/vayesta/edmet/edmet.py @@ -66,7 +66,7 @@ def e_nonlocal(self, value): def check_solver(self, solver): is_uhf = np.ndim(self.mo_coeff[1]) == 2 is_eb = True - check_solver_config(is_uhf, is_eb, solver, self.log) + check_solver_config(solver, is_uhf, is_eb, self.log) def kernel(self): t_start = timer() diff --git a/vayesta/edmet/fragment.py b/vayesta/edmet/fragment.py index e4488736d..5a0b62b65 100644 --- a/vayesta/edmet/fragment.py +++ b/vayesta/edmet/fragment.py @@ -119,7 +119,7 @@ def energy_couplings(self, value): def check_solver(self, solver): is_uhf = np.ndim(self.base.mo_coeff[1]) == 2 is_eb = True - check_solver_config(is_uhf, is_eb, solver, self.log) + check_solver_config(solver, is_uhf, is_eb, self.log) def get_fock(self): f = self.base.get_fock() From 662d8d4189059feea7af765c5d7dc6ac759a05ea Mon Sep 17 00:00:00 2001 From: Max Nusspickel Date: Sat, 23 Sep 2023 13:37:59 +0100 Subject: [PATCH 3/5] Fixes EBCC inconsistent solver exception --- vayesta/solver/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vayesta/solver/__init__.py b/vayesta/solver/__init__.py index 3b53fc83d..2a9727f57 100644 --- a/vayesta/solver/__init__.py +++ b/vayesta/solver/__init__.py @@ -106,11 +106,11 @@ def _get_solver_class_ebcc(solver: str, is_uhf: bool, is_eb: bool, log: Logger) # This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case. def get_right_cc(*args, **kwargs): - setansatz = kwargs.get("ansatz", None) + setansatz = kwargs.get("ansatz", solver) if setansatz != solver: raise ValueError( - "Desired CC ansatz specified differently in solver and solver_options.ansatz." - "Please use only specify via one approach, or ensure they agree." + f"solver '{solver}' does not match solver_options.ansatz " + f"{'setansatz'}; only specify via one argument or ensure they agree" ) kwargs["ansatz"] = solver return solver_cls(*args, **kwargs) From 02df0f0dc7aa881a0fb7fdcb5c63ab35e2ebffca Mon Sep 17 00:00:00 2001 From: Max Nusspickel Date: Sat, 23 Sep 2023 14:54:17 +0100 Subject: [PATCH 4/5] More fixes --- vayesta/solver/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vayesta/solver/__init__.py b/vayesta/solver/__init__.py index 2a9727f57..71bf3338d 100644 --- a/vayesta/solver/__init__.py +++ b/vayesta/solver/__init__.py @@ -101,16 +101,16 @@ def _get_solver_class_ebcc(solver: str, is_uhf: bool, is_eb: bool, log: Logger) if solver[:2] == "EB": solver = solver[2:] if solver == "CCSD" and is_eb: - log.warning("CCSD solver requested for coupled electron-boson system; defaulting to CCSD-SD-1-1.") solver = "CCSD-SD-1-1" + log.warning(f"CCSD solver requested for coupled electron-boson system; defaulting to {solver}.") # This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case. def get_right_cc(*args, **kwargs): - setansatz = kwargs.get("ansatz", solver) - if setansatz != solver: + setansatz = kwargs.get("ansatz", None) + if setansatz is not None and setansatz != solver: raise ValueError( f"solver '{solver}' does not match solver_options.ansatz " - f"{'setansatz'}; only specify via one argument or ensure they agree" + f"'{setansatz}'; only specify via one argument or ensure they agree" ) kwargs["ansatz"] = solver return solver_cls(*args, **kwargs) From 7888e35a1ad1a9c0cb17f3586e668821f33c3e88 Mon Sep 17 00:00:00 2001 From: Basil Ibrahim Date: Tue, 20 Aug 2024 21:19:04 +0100 Subject: [PATCH 5/5] Fix solver captialisation --- vayesta/solver/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vayesta/solver/__init__.py b/vayesta/solver/__init__.py index b7756b103..e4523fe39 100644 --- a/vayesta/solver/__init__.py +++ b/vayesta/solver/__init__.py @@ -65,9 +65,9 @@ def _get_solver_class(solver: str, is_uhf: bool, is_eb: bool, log: Logger) -> Ty ('FCI', True, False): UFCI_Solver, ('FCI', False, True): EB_EBFCI_Solver, ('FCI', True, True): EB_UEBFCI_Solver, - ('Dump', False, False): DumpSolver, - ('Dump', True, False): DumpSolver, - ('Callback', False, False): CallbackSolver, + ('DUMP', False, False): DumpSolver, + ('DUMP', True, False): DumpSolver, + ('CALLBACK', False, False): CallbackSolver, }