diff --git a/gusto/forcing.py b/gusto/forcing.py index 2ed43be0d..e5be8d371 100644 --- a/gusto/forcing.py +++ b/gusto/forcing.py @@ -8,7 +8,7 @@ from gusto.labels import ( transport, diffusion, time_derivative, hydrostatic ) -from gusto.logging import logger, DEBUG +from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual __all__ = ["Forcing"] @@ -97,21 +97,19 @@ def __init__(self, equation, alpha): a.form, L_implicit.form, self.xF, bcs=bcs ) - solver_parameters = {} - if logger.isEnabledFor(DEBUG): - solver_parameters["ksp_monitor_true_residual"] = None - self.solvers = {} self.solvers["explicit"] = LinearVariationalSolver( explicit_forcing_problem, - solver_parameters=solver_parameters, options_prefix="ExplicitForcingSolver" ) self.solvers["implicit"] = LinearVariationalSolver( implicit_forcing_problem, - solver_parameters=solver_parameters, options_prefix="ImplicitForcingSolver" ) + + if logger.isEnabledFor(DEBUG): + self.solvers["explicit"].snes.ksp.setMonitor(logging_ksp_monitor_true_residual) + self.solvers["implicit"].snes.ksp.setMonitor(logging_ksp_monitor_true_residual) def apply(self, x_in, x_nl, x_out, label): """ diff --git a/gusto/linear_solvers.py b/gusto/linear_solvers.py index 4942031b2..77c337b91 100644 --- a/gusto/linear_solvers.py +++ b/gusto/linear_solvers.py @@ -15,7 +15,7 @@ from pyop2.profiling import timed_function, timed_region from gusto.active_tracers import TracerVariableType -from gusto.logging import logger, DEBUG +from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.labels import linearisation, time_derivative, hydrostatic from gusto import thermodynamics from gusto.fml.form_manipulation_language import Term, drop @@ -55,17 +55,26 @@ def __init__(self, equations, alpha=0.5, solver_parameters=None, solver_parameters = p self.solver_parameters = solver_parameters - if logger.isEnabledFor(DEBUG): - self.solver_parameters["ksp_monitor_true_residual"] = None + # ~ if logger.isEnabledFor(DEBUG): + # ~ self.solver_parameters["ksp_monitor_true_residual"] = None # setup the solver self._setup_solver() - + + @staticmethod + def log_ksp_residuals(ksp): + if logger.isEnabledFor(DEBUG): + ksp.setMonitor(logging_ksp_monitor_true_residual) + @abstractproperty def solver_parameters(self): """Solver parameters for this solver""" pass - + + @abstractmethod + def _setup_solver(self): + pass + @abstractmethod def solve(self): pass @@ -151,11 +160,6 @@ def __init__(self, equations, alpha=0.5, logger.warning("default quadrature degree most likely not sufficient for this degree element") self.quadrature_degree = (5, 5) - if logger.isEnabledFor(DEBUG): - self.solver_parameters["ksp_monitor_true_residual"] = None - # Turn monitor on for the trace system too - self.solver_parameters["condensed_field"]["ksp_monitor_true_residual"] = None - super().__init__(equations, alpha, solver_parameters, overwrite_solver_parameters) @@ -344,7 +348,14 @@ def L_tr(f): # Store boundary conditions for the div-conforming velocity to apply # post-solve self.bcs = self.equations.bcs['u'] - + + # Log residuals on hybridized solver + self.log_ksp_residuals(self.hybridized_solver.snes.ksp) + # Log residuals on the trace system too + from gusto.logging import attach_custom_monitor + python_context = self.hybridized_solver.snes.ksp.pc.getPythonContext() + attach_custom_monitor(python_context, logging_ksp_monitor_true_residual) + @timed_function("Gusto:LinearSolve") def solve(self, xrhs, dy): """ @@ -500,6 +511,9 @@ def trace_nullsp(T): rhs(b_eqn), self.b) self.b_solver = LinearVariationalSolver(b_problem) + + # Log residuals on hybridized solver + self.log_ksp_residuals(self.up_solver.snes.ksp) @timed_function("Gusto:LinearSolve") def solve(self, xrhs, dy): diff --git a/gusto/logging.py b/gusto/logging.py index 9db5d9211..17838100a 100644 --- a/gusto/logging.py +++ b/gusto/logging.py @@ -33,6 +33,8 @@ from logging import NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL # noqa: F401 from pathlib import Path +from firedrake.slate.static_condensation import scpc, hybridization +from petsc4py import PETSc from pyop2.mpi import COMM_WORLD __all__ = [ @@ -186,3 +188,103 @@ def update_logfile_location(new_path): "More than one log handler with name `gusto-temp-file-log`\n" "Logging has been set up incorrectly" ) + + +# We want a map from ENUM to Norm names +_norm_to_enum = {k: v for k, v in PETSc.KSP.NormType.__dict__.items() if isinstance(v, int)} +_enum_to_norm = {v: k.lower() for k, v in _norm_to_enum.items() if 'NORM_' not in k} + + +# The logging monitors will only log at level debug, but you should avoid +# adding an expensive Python callback the log level is not DEBUG by +# checking the logger like so: +# ``` +# if logger.isEnabledFor(DEBUG): +# ksp.setMonitor(logging_ksp_monitor) +# ``` +def logging_ksp_monitor(ksp, iteration, residual_norm): + ''' + Clone of C code at: + https://petsc.org/main/src/ksp/ksp/interface/iterativ.c.html#KSPMonitorResidual + Example output: + Residual norms for firedrake_0_ solve + 0 KSP Residual norm 3.175267221735e+00 + + ''' + tab_level = ksp.getTabLevel() + tab = ' ' + if iteration == 0: + logger.debug(tab*tab_level + f'Residual norms for {ksp.prefix} solve') + logger.debug( + tab*(tab_level - 1) + + f'{iteration: 5d} KSP Residual norm {residual_norm:14.12e}' + ) + + +def logging_ksp_monitor_true_residual(ksp, iteration, residual_norm): + ''' + Clone of C code: + https://petsc.org/main/src/ksp/ksp/interface/iterativ.c.html#KSPMonitorTrueResidual + Example output: + Residual norms for firedrake_0_ solve + 0 KSP preconditioned resid norm 3.175267221735e+00 true resid norm 3.175267221735e+00 ||r(i)||/||b|| 1.000000000000e+00 + + ''' + tab_level = ksp.getTabLevel() + tab = ' ' + residual = ksp.buildResidual() + true_norm = residual.norm(PETSc.NormType.NORM_2) + bnorm = ksp.vec_rhs.norm(PETSc.NormType.NORM_2) + if bnorm == 0: + residual_over_b = float('inf') + else: + residual_over_b = true_norm / bnorm + if iteration == 0: + logger.debug(tab*tab_level + f'Residual norms for {ksp.prefix} solve') + logger.debug( + tab*(tab_level - 1) + + f'{iteration: 5d} KSP {_enum_to_norm[ksp.norm_type]} resid norm {residual_norm:14.12e}' + + f' true resid norm {true_norm:14.12e}' + + f' ||r(i)||/||b|| {residual_over_b:14.12e}' + ) + + +def _wrap_method(obj, method_str, ksp_str, monitor): + ''' + Used to patch the method with name `method_str` of the object `obj`, + by setting the monitor of the solver with name `ksp_str` to `monitor`. + + Intended use: + ``` + foo.initialize = _wraps_initialize( + context + "initialize", + "my_ksp", + my_custom_monitor + ) + ``` + + If this is confusing, do not try and call this function! + ''' + old_init = getattr(obj, method_str) + def new_init(pc): + old_init(pc) + getattr(obj, ksp_str).setMonitor(monitor) + return new_init + + +def attach_custom_monitor(context, monitor): + if isinstance(context, scpc.SCPC): + context.initialize = _wrap_method( + context, + "initialize", + "condensed_ksp", + monitor + ) + elif isinstance(context, hybridization.HybridizationPC): + context.initialize = _wrap_method( + context, + "initialize", + "trace_ksp", + monitor + ) diff --git a/gusto/time_discretisation.py b/gusto/time_discretisation.py index 602494efd..98e8d53f5 100644 --- a/gusto/time_discretisation.py +++ b/gusto/time_discretisation.py @@ -19,7 +19,7 @@ replace_subject, replace_test_function, Term, all_terms, drop ) from gusto.labels import time_derivative, prognostic, physics -from gusto.logging import logger, DEBUG +from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.wrappers import * @@ -98,8 +98,6 @@ def __init__(self, domain, field_name=None, solver_parameters=None, 'sub_pc_type': 'ilu'} else: self.solver_parameters = solver_parameters - if logger.isEnabledFor(DEBUG): - self.solver_parameters["ksp_monitor_true_residual"] = None def setup(self, equation, apply_bcs=True, *active_labels): """ @@ -215,7 +213,14 @@ def solver(self): # setup solver using lhs and rhs defined in derived class problem = NonlinearVariationalProblem(self.lhs-self.rhs, self.x_out, bcs=self.bcs) solver_name = self.field_name+self.__class__.__name__ - return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name) + solver = NonlinearVariationalSolver( + problem, + solver_parameters=self.solver_parameters, + options_prefix=solver_name + ) + if logger.isEnabledFor(DEBUG): + solver.snes.ksp.setMonitor(logging_ksp_monitor_true_residual) + return solver @abstractmethod def apply(self, x_out, x_in):