From 3c563a90dd3422fb3d1ea45710bad10ed073cfeb Mon Sep 17 00:00:00 2001 From: Thomas Bendall Date: Fri, 1 Nov 2024 15:38:33 +0000 Subject: [PATCH 1/4] Conservative projection (#508) Co-authored-by: Tim Andrews --- gusto/core/__init__.py | 23 +- gusto/core/configuration.py | 20 +- gusto/core/conservative_projection.py | 93 ++++++++ gusto/equations/prognostic_equations.py | 11 + gusto/recovery/reversible_recovery.py | 72 +++++- .../time_discretisation.py | 3 +- gusto/time_discretisation/wrappers.py | 84 ++++++- .../test_tracer_conservative_transport.py | 208 ++++++++++++++++++ .../test_conservative_recovery.py | 112 ++++++++++ .../test_reversible_recovery.py | 4 +- unit-tests/test_conservative_projection.py | 64 ++++++ 11 files changed, 667 insertions(+), 27 deletions(-) create mode 100644 gusto/core/conservative_projection.py create mode 100644 integration-tests/transport/test_tracer_conservative_transport.py create mode 100644 unit-tests/recovery_tests/test_conservative_recovery.py create mode 100644 unit-tests/test_conservative_projection.py diff --git a/gusto/core/__init__.py b/gusto/core/__init__.py index 163f9ddff..432c66b9d 100644 --- a/gusto/core/__init__.py +++ b/gusto/core/__init__.py @@ -1,11 +1,12 @@ -from gusto.core.configuration import * # noqa -from gusto.core.coordinates import * # noqa -from gusto.core.coord_transforms import * # noqa -from gusto.core.domain import * # noqa -from gusto.core.fields import * # noqa -from gusto.core.function_spaces import * # noqa -from gusto.core.io import * # noqa -from gusto.core.kernels import * # noqa -from gusto.core.labels import * # noqa -from gusto.core.logging import * # noqa -from gusto.core.meshes import * # noqa \ No newline at end of file +from gusto.core.configuration import * # noqa +from gusto.core.conservative_projection import * # noqa +from gusto.core.coordinates import * # noqa +from gusto.core.coord_transforms import * # noqa +from gusto.core.domain import * # noqa +from gusto.core.fields import * # noqa +from gusto.core.function_spaces import * # noqa +from gusto.core.io import * # noqa +from gusto.core.kernels import * # noqa +from gusto.core.labels import * # noqa +from gusto.core.logging import * # noqa +from gusto.core.meshes import * # noqa \ No newline at end of file diff --git a/gusto/core/configuration.py b/gusto/core/configuration.py index 252f25187..a3796b40d 100644 --- a/gusto/core/configuration.py +++ b/gusto/core/configuration.py @@ -8,7 +8,8 @@ "IntegrateByParts", "TransportEquationType", "OutputParameters", "BoussinesqParameters", "CompressibleParameters", "ShallowWaterParameters", - "EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions", "MixedFSOptions", + "EmbeddedDGOptions", "ConservativeEmbeddedDGOptions", "RecoveryOptions", + "ConservativeRecoveryOptions", "SUPGOptions", "MixedFSOptions", "SpongeLayerParameters", "DiffusionParameters", "BoundaryLayerParameters" ] @@ -164,6 +165,14 @@ class EmbeddedDGOptions(WrapperOptions): embedding_space = None +class ConservativeEmbeddedDGOptions(EmbeddedDGOptions): + """Specifies options for a conservative embedded DG method.""" + + project_back_method = 'conservative_project' + rho_name = None + orig_rho_space = None + + class RecoveryOptions(WrapperOptions): """Specifies options for a recovery wrapper method.""" @@ -177,6 +186,15 @@ class RecoveryOptions(WrapperOptions): broken_method = 'interpolate' +class ConservativeRecoveryOptions(RecoveryOptions): + """Specifies options for a conservative recovery wrapper method.""" + + rho_name = None + orig_rho_space = None + project_high_method = 'conservative_project' + project_low_method = 'conservative_project' + + class SUPGOptions(WrapperOptions): """Specifies options for an SUPG scheme.""" diff --git a/gusto/core/conservative_projection.py b/gusto/core/conservative_projection.py new file mode 100644 index 000000000..1ab1455d6 --- /dev/null +++ b/gusto/core/conservative_projection.py @@ -0,0 +1,93 @@ +""" +This provides an operator for perform a conservative projection. + +The :class:`ConservativeProjector` provided in this module is an operator that +projects a field such as a mixing ratio from one function space to another, +weighted by a density field to ensure that mass is conserved by the projection. +""" + +from firedrake import (Function, TestFunction, TrialFunction, lhs, rhs, inner, + dx, LinearVariationalProblem, LinearVariationalSolver, + Constant, assemble) +import ufl + +__all__ = ["ConservativeProjector"] + + +class ConservativeProjector(object): + """ + Projects a field such that mass is conserved. + + This object is designed for projecting fields such as mixing ratios of + tracer species from one function space to another, but weighted by density + such that mass is conserved by the projection. + """ + + def __init__(self, rho_source, rho_target, m_source, m_target, + subtract_mean=False): + """ + Args: + rho_source (:class:`Function`): the density to use for weighting the + source mixing ratio field. Can also be a :class:`ufl.Expr`. + rho_target (:class:`Function`): the density to use for weighting the + target mixing ratio field. Can also be a :class:`ufl.Expr`. + m_source (:class:`Function`): the source mixing ratio field. Can + also be a :class:`ufl.Expr`. + m_target (:class:`Function`): the target mixing ratio field to + compute. + subtract_mean (bool, optional): whether to solve the projection by + subtracting the mean value of m for both sides. This is more + expensive as it involves calculating the mean, but will ensure + preservation of a constant when projecting to a continuous + space. Default to False. + + Raises: + RuntimeError: the geometric shape of the two rho fields must be equal. + RuntimeError: the geometric shape of the two m fields must be equal. + """ + + self.subtract_mean = subtract_mean + + if not isinstance(rho_source, (ufl.core.expr.Expr, Function)): + raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(rho_source)) + + if not isinstance(rho_target, (ufl.core.expr.Expr, Function)): + raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(rho_target)) + + if not isinstance(m_source, (ufl.core.expr.Expr, Function)): + raise ValueError("Can only recover UFL expression or Functions not '%s'" % type(m_source)) + + # Check shape values + if m_source.ufl_shape != m_target.ufl_shape: + raise RuntimeError('Shape mismatch between source %s and target function spaces %s in project' % (m_source.ufl_shape, m_target.ufl_shape)) + + if rho_source.ufl_shape != rho_target.ufl_shape: + raise RuntimeError('Shape mismatch between source %s and target function spaces %s in project' % (rho_source.ufl_shape, rho_target.ufl_shape)) + + self.m_source = m_source + self.m_target = m_target + + V = self.m_target.function_space() + mesh = V.mesh() + + self.m_mean = Constant(0.0, domain=mesh) + self.volume = assemble(Constant(1.0, domain=mesh)*dx) + + test = TestFunction(V) + m_trial = TrialFunction(V) + eqn = (rho_source*inner(test, m_source - self.m_mean)*dx + - rho_target*inner(test, m_trial - self.m_mean)*dx) + problem = LinearVariationalProblem(lhs(eqn), rhs(eqn), self.m_target) + self.solver = LinearVariationalSolver(problem) + + def project(self): + """Apply the projection.""" + + # Compute mean value + if self.subtract_mean: + self.m_mean.assign(assemble(self.m_source*dx) / self.volume) + + # Solve projection + self.solver.solve() + + return self.m_target diff --git a/gusto/equations/prognostic_equations.py b/gusto/equations/prognostic_equations.py index c8c708b3c..7369f5ee2 100644 --- a/gusto/equations/prognostic_equations.py +++ b/gusto/equations/prognostic_equations.py @@ -305,6 +305,17 @@ def add_tracers_to_prognostics(self, domain, active_tracers): name of the active tracer. """ + # Check if there are any conservatively transported tracers. + # If so, ensure that the reference density is indexed before this tracer. + for i in range(len(active_tracers) - 1): + tracer = active_tracers[i] + if tracer.transport_eqn == TransportEquationType.tracer_conservative: + ref_density = next(x for x in active_tracers if x.name == tracer.density_name) + j = active_tracers.index(ref_density) + if j > i: + # Swap the indices of the tracer and the reference density + active_tracers[i], active_tracers[j] = active_tracers[j], active_tracers[i] + # Loop through tracer fields and add field names and spaces for tracer in active_tracers: if isinstance(tracer, ActiveTracer): diff --git a/gusto/recovery/reversible_recovery.py b/gusto/recovery/reversible_recovery.py index d9ad661d1..fc8fbd332 100644 --- a/gusto/recovery/reversible_recovery.py +++ b/gusto/recovery/reversible_recovery.py @@ -3,9 +3,12 @@ higher-order function space. """ +from gusto.core.conservative_projection import ConservativeProjector from firedrake import (Projector, Function, Interpolator) from .recovery import Recoverer +__all__ = ["ReversibleRecoverer", "ConservativeRecoverer"] + class ReversibleRecoverer(object): """ @@ -13,10 +16,11 @@ class ReversibleRecoverer(object): field into a higher-order discontinuous space. This uses the recovery operator, but with further adjustments to ensure reversibility. - :arg source_field: the source field. - :arg target_field: the target_field. - :arg reconstruct_opts: an object containing the various options for the - reconstruction. + Args: + source_field (:class:`Function`): the source field. + target_field (:class:`Function`): the target field. + reconstruct_opts (:class:`RecoveryOptions`): an object containing the + various options for the reconstruction. """ def __init__(self, source_field, target_field, reconstruct_opts): @@ -92,3 +96,63 @@ def project(self): self.q_corr_low.assign(self.q_low - self.q_corr_low) self.injector.interpolate() if self.interp_inj else self.injector.project() self.q_high.assign(self.q_corr_high + self.q_rec_high) + + +class ConservativeRecoverer(object): + """ + An object for performing a reconstruction of a low-order discontinuous + field into a higher-order discontinuous space, but such that mass is + conserved. This uses the recovery operator, but with further adjustments to + ensure both reversibility and mass conservation. + + Args: + source_field (:class:`Function`): the source field. + target_field (:class:`Function`): the target field. + source_density (:class:`Function`): the source density field. + target_density (:class:`Function`): the target density field. + reconstruct_opts (:class:`RecoveryOptions`): an object containing the + various options for the reconstruction. + """ + def __init__(self, source_field, target_field, source_density, + target_density, reconstruct_opts): + + self.opts = reconstruct_opts + + # Declare the fields used by the reconstructor + self.q_low = source_field + self.q_high = target_field + self.q_recovered = Function(self.opts.recovered_space) + self.q_corr_low = Function(source_field.function_space()) + self.q_corr_high = Function(target_field.function_space()) + self.q_rec_high = Function(target_field.function_space()) + + # -------------------------------------------------------------------- # + # Set up the operators for different transformations + # -------------------------------------------------------------------- # + + # Does recovery by first projecting into broken space then averaging + self.recoverer = Recoverer(self.q_low, self.q_recovered, + method=self.opts.broken_method, + boundary_method=self.opts.boundary_method) + + # Obtain the recovered field in the higher order space + self.projector_high = Projector(self.q_recovered, self.q_rec_high) + + # Obtain the correction in the lower order space + # Swap density arguments! + self.projector_low = ConservativeProjector(target_density, source_density, + self.q_rec_high, self.q_corr_low, + subtract_mean=True) + + # Final injection operator + # Should identify low order field in higher order space + self.injector = ConservativeProjector(source_density, target_density, + self.q_corr_low, self.q_corr_high) + + def project(self): + self.recoverer.project() + self.projector_high.project() + self.projector_low.project() + self.q_corr_low.assign(self.q_low - self.q_corr_low) + self.injector.project() + self.q_high.assign(self.q_corr_high + self.q_rec_high) diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index 50243e310..df108a615 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -94,7 +94,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None, 'Time discretisation: suboption SUPG is currently not implemented within MixedOptions') else: raise RuntimeError( - f'Time discretisation: suboption wrapper {self.wrapper_name} not implemented') + f'Time discretisation: suboption wrapper {suboption.name} not implemented') + elif self.wrapper_name == "embedded_dg": self.wrapper = EmbeddedDGWrapper(self, options) elif self.wrapper_name == "recovered": diff --git a/gusto/time_discretisation/wrappers.py b/gusto/time_discretisation/wrappers.py index f37e402c9..a1803cad0 100644 --- a/gusto/time_discretisation/wrappers.py +++ b/gusto/time_discretisation/wrappers.py @@ -11,8 +11,9 @@ ) from firedrake.fml import Term from gusto.core.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions -from gusto.recovery import Recoverer, ReversibleRecoverer +from gusto.recovery import Recoverer, ReversibleRecoverer, ConservativeRecoverer from gusto.core.labels import transporting_velocity +from gusto.core.conservative_projection import ConservativeProjector import ufl __all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedFSWrapper"] @@ -34,6 +35,7 @@ def __init__(self, time_discretisation, wrapper_options): self.options = wrapper_options self.solver_parameters = None self.original_space = None + self.is_conservative = False @abstractmethod def setup(self, original_space): @@ -123,6 +125,7 @@ def setup(self, original_space, post_apply_bcs): self.x_in = Function(self.function_space) self.x_out = Function(self.function_space) + self.x_in_orig = Function(original_space) if self.time_discretisation.idx is None: self.x_projected = Function(self.original_space) @@ -134,6 +137,19 @@ def setup(self, original_space, post_apply_bcs): bcs=post_apply_bcs) elif self.options.project_back_method == 'recover': self.x_out_projector = Recoverer(self.x_out, self.x_projected) + elif self.options.project_back_method == 'conservative_project': + self.is_conservative = True + self.rho_name = self.options.rho_name + self.rho_in_orig = Function(self.options.orig_rho_space) + self.rho_out_orig = Function(self.options.orig_rho_space) + self.rho_in_embedded = Function(self.function_space) + self.rho_out_embedded = Function(self.function_space) + self.x_in_projector = ConservativeProjector( + self.rho_in_orig, self.rho_in_embedded, + self.x_in_orig, self.x_in) + self.x_out_projector = ConservativeProjector( + self.rho_out_embedded, self.rho_out_orig, + self.x_out, self.x_projected, subtract_mean=True) else: raise NotImplementedError( 'EmbeddedDG Wrapper: project_back_method' @@ -152,10 +168,15 @@ def pre_apply(self, x_in): x_in (:class:`Function`): the original input field. """ - try: - self.x_in.interpolate(x_in) - except NotImplementedError: - self.x_in.project(x_in) + self.x_in_orig.assign(x_in) + + if self.is_conservative: + self.x_in_projector.project() + else: + try: + self.x_in.interpolate(x_in) + except NotImplementedError: + self.x_in.project(x_in) def post_apply(self, x_out): """ @@ -215,7 +236,7 @@ def setup(self, original_space, post_apply_bcs): # Internal variables to be used # -------------------------------------------------------------------- # - self.x_in_tmp = Function(self.original_space) + self.x_in_orig = Function(self.original_space) self.x_in = Function(self.function_space) self.x_out = Function(self.function_space) @@ -225,7 +246,19 @@ def setup(self, original_space, post_apply_bcs): self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) # Operator to recover to higher discontinuous space - self.x_recoverer = ReversibleRecoverer(self.x_in_tmp, self.x_in, self.options) + if self.options.project_low_method == 'conservative_project': + self.is_conservative = True + self.rho_name = self.options.rho_name + self.rho_in_orig = Function(self.options.orig_rho_space) + self.rho_out_orig = Function(self.options.orig_rho_space) + self.rho_in_embedded = Function(self.function_space) + self.rho_out_embedded = Function(self.function_space) + self.x_recoverer = ConservativeRecoverer(self.x_in_orig, self.x_in, + self.rho_in_orig, + self.rho_in_embedded, + self.options) + else: + self.x_recoverer = ReversibleRecoverer(self.x_in_orig, self.x_in, self.options) # Operators for projecting back self.interp_back = (self.options.project_low_method == 'interpolate') @@ -237,6 +270,10 @@ def setup(self, original_space, post_apply_bcs): elif self.options.project_low_method == 'recover': self.x_out_projector = Recoverer(self.x_out, self.x_projected, method=self.options.broken_method) + elif self.options.project_low_method == 'conservative_project': + self.x_out_projector = ConservativeProjector( + self.rho_out_embedded, self.rho_out_orig, + self.x_out, self.x_projected, subtract_mean=True) else: raise NotImplementedError( 'Recovery Wrapper: project_back_method' @@ -251,7 +288,7 @@ def pre_apply(self, x_in): x_in (:class:`Function`): the original input field. """ - self.x_in_tmp.assign(x_in) + self.x_in_orig.assign(x_in) self.x_recoverer.project() def post_apply(self, x_out): @@ -416,6 +453,7 @@ def setup(self): self.function_space = MixedFunctionSpace(self.wrapper_spaces) self.x_in = Function(self.function_space) self.x_out = Function(self.function_space) + self.is_conservative = any([subwrapper.is_conservative for subwrapper in self.subwrappers.values()]) def pre_apply(self, x_in): """ @@ -430,6 +468,8 @@ def pre_apply(self, x_in): if field_name in self.subwrappers: subwrapper = self.subwrappers[field_name] + if subwrapper.is_conservative: + self.pre_update_rho(subwrapper) subwrapper.pre_apply(field) x_in_sub.assign(subwrapper.x_in) else: @@ -449,6 +489,34 @@ def post_apply(self, x_out): if field_name in self.subwrappers: subwrapper = self.subwrappers[field_name] subwrapper.x_out.assign(field) + if subwrapper.is_conservative: + self.post_update_rho(subwrapper) subwrapper.post_apply(x_out_sub) else: x_out_sub.assign(field) + + def pre_update_rho(self, subwrapper): + """ + Updates the stored density field for the pre-apply for the subwrapper. + + Args: + subwrapper (:class:`Wrapper`): the original input field. + """ + + rho_subwrapper = self.subwrappers[subwrapper.rho_name] + + subwrapper.rho_in_orig.assign(rho_subwrapper.x_in_orig) + subwrapper.rho_in_embedded.assign(rho_subwrapper.x_in) + + def post_update_rho(self, subwrapper): + """ + Updates the stored density field for the post-apply for the subwrapper. + + Args: + subwrapper (:class:`Wrapper`): the original input field. + """ + + rho_subwrapper = self.subwrappers[subwrapper.rho_name] + + subwrapper.rho_out_orig.assign(rho_subwrapper.x_projected) + subwrapper.rho_out_embedded.assign(rho_subwrapper.x_out) diff --git a/integration-tests/transport/test_tracer_conservative_transport.py b/integration-tests/transport/test_tracer_conservative_transport.py new file mode 100644 index 000000000..bc12dca8d --- /dev/null +++ b/integration-tests/transport/test_tracer_conservative_transport.py @@ -0,0 +1,208 @@ +""" +Tests the conservative transport of a mixing ratio and dry density, both when +they are defined on the same and different function spaces. This checks +that there is conservation of the total species mass (dry density times the +mixing ratio) and that there is consistency (a constant field will remain +constant). +""" + +from gusto import * +from firedrake import ( + PeriodicIntervalMesh, ExtrudedMesh, exp, cos, sin, SpatialCoordinate, + assemble, dx, FunctionSpace, pi, min_value, as_vector, BrokenElement, + errornorm +) +import pytest + + +def setup_conservative_transport(dirname, pair_of_spaces, desirable_property): + + # Domain + Lx = 2000. + Hz = 2000. + + # Time parameters + dt = 2. + tmax = 2000. + + nlayers = 10. # horizontal layers + columns = 10. # number of columns + + # Define the spaces for the tracers + if pair_of_spaces == 'same_order_1': + rho_d_space = 'DG' + m_X_space = 'DG' + space_order = 1 + elif pair_of_spaces == 'diff_order_0': + rho_d_space = 'DG' + m_X_space = 'theta' + space_order = 0 + elif pair_of_spaces == 'diff_order_1': + rho_d_space = 'DG' + m_X_space = 'theta' + space_order = 1 + + period_mesh = PeriodicIntervalMesh(columns, Lx) + mesh = ExtrudedMesh(period_mesh, layers=nlayers, layer_height=Hz/nlayers) + domain = Domain(mesh, dt, "CG", space_order) + x, z = SpatialCoordinate(mesh) + + V_rho = domain.spaces(rho_d_space) + V_m_X = domain.spaces(m_X_space) + + m_X = ActiveTracer(name='m_X', space=m_X_space, + variable_type=TracerVariableType.mixing_ratio, + transport_eqn=TransportEquationType.tracer_conservative, + density_name='rho_d') + + rho_d = ActiveTracer(name='rho_d', space=rho_d_space, + variable_type=TracerVariableType.density, + transport_eqn=TransportEquationType.conservative) + + # Define m_X first to test that the tracers will be + # automatically re-ordered such that the density field + # is indexed before the mixing ratio. + tracers = [m_X, rho_d] + + # Equation + V = domain.spaces("HDiv") + eqn = CoupledTransportEquation(domain, active_tracers=tracers, Vu=V) + + # IO + output = OutputParameters(dirname=dirname) + io = IO(domain, output) + + if pair_of_spaces == 'diff_order_0': + VCG1 = FunctionSpace(mesh, 'CG', 1) + VDG1 = domain.spaces('DG1_equispaced') + + suboptions = { + 'rho_d': RecoveryOptions( + embedding_space=VDG1, + recovered_space=VCG1, + project_low_method='recover', + boundary_method=BoundaryMethod.taylor + ), + 'm_X': ConservativeRecoveryOptions( + embedding_space=VDG1, + recovered_space=VCG1, + boundary_method=BoundaryMethod.taylor, + rho_name='rho_d', + orig_rho_space=V_rho + ) + } + elif pair_of_spaces == 'diff_order_1': + Vt_brok = FunctionSpace(mesh, BrokenElement(V_m_X.ufl_element())) + suboptions = { + 'rho_d': EmbeddedDGOptions(embedding_space=Vt_brok), + 'm_X': ConservativeEmbeddedDGOptions( + rho_name='rho_d', + orig_rho_space=V_rho + ) + } + else: + suboptions = {} + + opts = MixedFSOptions(suboptions=suboptions) + + transport_scheme = SSPRK3( + domain, options=opts, rk_formulation=RungeKuttaFormulation.predictor + ) + transport_methods = [DGUpwind(eqn, "m_X"), DGUpwind(eqn, "rho_d")] + + # Timestepper + time_varying = True + stepper = PrescribedTransport( + eqn, transport_scheme, io, time_varying, transport_methods + ) + + # Initial Conditions + # Specify locations of the two Gaussians + xc1 = 5.*Lx/8. + zc1 = Hz/2. + + xc2 = 3.*Lx/8. + zc2 = Hz/2. + + def l2_dist(xc, zc): + return min_value(abs(x-xc), Lx-abs(x-xc))**2 + (z-zc)**2 + + lc = 2.*Lx/25. + m0 = 0.02 + + # Set the initial state from the configuration choice + if desirable_property == 'conservation': + f0 = 0.05 + + rho_t = 0.5 + rho_b = 1. + + rho_d_0 = rho_b + z*(rho_t-rho_b)/Hz + + g1 = f0*exp(-l2_dist(xc1, zc1)/(lc**2)) + g2 = f0*exp(-l2_dist(xc2, zc2)/(lc**2)) + + m_X_0 = m0 + g1 + g2 + + else: + f0 = 0.5 + rho_b = 0.5 + + g1 = f0*exp(-l2_dist(xc1, zc1)/(lc**2)) + g2 = f0*exp(-l2_dist(xc2, zc2)/(lc**2)) + + rho_d_0 = rho_b + g1 + g2 + + # Constant mass field + m_X_0 = m0 + 0*x + + # Set up the divergent, time-varying, velocity field + U = Lx/tmax + W = U/10. + + def u_t(t): + xd = x - U*t + u = U - (W*pi*Lx/Hz)*cos(pi*t/tmax)*cos(2*pi*xd/Lx)*cos(pi*z/Hz) + w = 2*pi*W*cos(pi*t/tmax)*sin(2*pi*xd/Lx)*sin(pi*z/Hz) + + u_expr = as_vector((u, w)) + + return u_expr + + stepper.setup_prescribed_expr(u_t) + + stepper.fields("m_X").interpolate(m_X_0) + stepper.fields("rho_d").interpolate(rho_d_0) + stepper.fields("u").project(u_t(0)) + + m_X_init = Function(V_m_X) + rho_d_init = Function(V_rho) + + m_X_init.assign(stepper.fields("m_X")) + rho_d_init.assign(stepper.fields("rho_d")) + + return stepper, m_X_init, rho_d_init + + +@pytest.mark.parametrize("pair_of_spaces", ["same_order_1", "diff_order_0", "diff_order_1"]) +@pytest.mark.parametrize("desirable_property", ["consistency", "conservation"]) +def test_conservative_transport(tmpdir, pair_of_spaces, desirable_property): + + # Setup and run + dirname = str(tmpdir) + + stepper, m_X_0, rho_d_0 = \ + setup_conservative_transport(dirname, pair_of_spaces, desirable_property) + + # Run for five timesteps + stepper.run(t=0, tmax=10) + m_X = stepper.fields("m_X") + rho_d = stepper.fields("rho_d") + + # Perform the check + if desirable_property == 'consistency': + assert errornorm(m_X_0, m_X) < 2e-13, "conservative transport is not consistent" + else: + rho_X_init = assemble(m_X_0*rho_d_0*dx) + rho_X_final = assemble(m_X*rho_d*dx) + assert abs((rho_X_init - rho_X_final)/rho_X_init) < 1e-14, "conservative transport is not conservative" diff --git a/unit-tests/recovery_tests/test_conservative_recovery.py b/unit-tests/recovery_tests/test_conservative_recovery.py new file mode 100644 index 000000000..796e75397 --- /dev/null +++ b/unit-tests/recovery_tests/test_conservative_recovery.py @@ -0,0 +1,112 @@ +""" +Test whether the conservative recovery process is working appropriately. +""" + +from firedrake import (PeriodicIntervalMesh, IntervalMesh, ExtrudedMesh, + SpatialCoordinate, FiniteElement, FunctionSpace, + TensorProductElement, Function, interval, norm, errornorm, + assemble) +from gusto import * +import numpy as np +import pytest + +np.random.seed(0) + + +@pytest.fixture +def mesh(geometry): + + L = 100. + H = 100. + + deltax = L / 5. + deltaz = H / 5. + nlayers = int(H/deltaz) + ncolumns = int(L/deltax) + + if geometry == "periodic": + m = PeriodicIntervalMesh(ncolumns, L) + elif geometry == "non-periodic": + m = IntervalMesh(ncolumns, L) + + extruded_mesh = ExtrudedMesh(m, layers=nlayers, layer_height=deltaz) + + return extruded_mesh + + +def expr(geometry, mesh, configuration): + + x, z = SpatialCoordinate(mesh) + + if configuration == 'rho_constant': + rho_expr = Constant(2.0) + if geometry == "periodic": + m_expr = np.random.randn() + np.random.randn() * z + elif geometry == "non-periodic": + m_expr = np.random.randn() + np.random.randn() * x + np.random.randn() * z + + elif configuration == 'm_constant': + m_expr = Constant(0.01) + if geometry == "periodic": + rho_expr = np.random.randn() + np.random.randn() * z + elif geometry == "non-periodic": + rho_expr = np.random.randn() + np.random.randn() * x + np.random.randn() * z + + return rho_expr, m_expr + + +@pytest.mark.parametrize("configuration", ["m_constant", "rho_constant"]) +@pytest.mark.parametrize("geometry", ["periodic", "non-periodic"]) +def test_conservative_recovery(geometry, mesh, configuration): + + rho_expr, m_expr = expr(geometry, mesh, configuration) + + # construct theta elemnt + cell = mesh._base_mesh.ufl_cell().cellname() + w_hori = FiniteElement("DG", cell, 0) + w_vert = FiniteElement("CG", interval, 1) + theta_element = TensorProductElement(w_hori, w_vert) + + # spaces + DG0 = FunctionSpace(mesh, "DG", 0) + CG1 = FunctionSpace(mesh, "CG", 1) + DG1 = FunctionSpace(mesh, "DG", 1) + Vt = FunctionSpace(mesh, theta_element) + + # set up density + rho_DG1 = Function(DG1).interpolate(rho_expr) + rho_DG0 = Function(DG0).project(rho_DG1) + + # mixing ratio fields + m_Vt = Function(Vt).interpolate(m_expr) + m_DG1_approx = Function(DG1).interpolate(m_expr) + m_Vt_back = Function(Vt) + m_DG1 = Function(DG1) + + options = ConservativeRecoveryOptions(embedding_space=DG1, + recovered_space=CG1, + boundary_method=BoundaryMethod.taylor) + + # make the recoverers and do the recovery + conservative_recoverer = ConservativeRecoverer(m_Vt, m_DG1, + rho_DG0, rho_DG1, options) + back_projector = ConservativeProjector(rho_DG1, rho_DG0, m_DG1, m_Vt_back, + subtract_mean=True) + + conservative_recoverer.project() + back_projector.project() + + # check various aspects of the process + m_high_diff = errornorm(m_DG1, m_DG1_approx) / norm(m_DG1_approx) + m_low_diff = errornorm(m_Vt_back, m_Vt) / norm(m_Vt) + mass_low = assemble(rho_DG0*m_Vt*dx) + mass_high = assemble(rho_DG1*m_DG1*dx) + + assert (mass_low - mass_high) / mass_high < 5e-14, \ + f'Conservative recovery on {geometry} vertical slice not conservative for {configuration} configuration' + assert m_low_diff < 2e-14, \ + f'Irreversible conservative recovery on {geometry} vertical slice for {configuration} configuration' + + if configuration in ['m_constant', 'rho_constant']: + assert m_high_diff < 2e-14, \ + f'Inaccurate conservative recovery on {geometry} vertical slice for {configuration} configuration' diff --git a/unit-tests/recovery_tests/test_reversible_recovery.py b/unit-tests/recovery_tests/test_reversible_recovery.py index 721b0158f..474e9ba0f 100644 --- a/unit-tests/recovery_tests/test_reversible_recovery.py +++ b/unit-tests/recovery_tests/test_reversible_recovery.py @@ -9,8 +9,8 @@ """ from firedrake import (IntervalMesh, CubedSphereMesh, IcosahedralSphereMesh, - SpatialCoordinate, FunctionSpace, - Function, norm, errornorm, as_vector) + SpatialCoordinate, FunctionSpace, Interpolator, + Projector, Function, norm, errornorm, as_vector) from gusto import * import numpy as np import pytest diff --git a/unit-tests/test_conservative_projection.py b/unit-tests/test_conservative_projection.py new file mode 100644 index 000000000..bb47c891e --- /dev/null +++ b/unit-tests/test_conservative_projection.py @@ -0,0 +1,64 @@ +""" +This tests the ConservativeProjector object, by projecting a mixing ratio from +DG1 to DG0, relative to different density fields, and checking that the tracer +mass is conserved. +""" + +from firedrake import (UnitSquareMesh, FunctionSpace, Constant, + Function, assemble, dx, sin, SpatialCoordinate) +from gusto import ConservativeProjector +import pytest + + +@pytest.mark.parametrize("projection", ["discontinuous", "continuous"]) +def test_conservative_projection(projection): + + # Set up mesh on plane + mesh = UnitSquareMesh(3, 3) + + # Function spaces and functions + DG0 = FunctionSpace(mesh, "DG", 0) + DG1 = FunctionSpace(mesh, "DG", 1) + + rho_DG0 = Function(DG0) + rho_DG1 = Function(DG1) + m_DG1 = Function(DG1) + + if projection == "continuous": + CG1 = FunctionSpace(mesh, "CG", 1) + m_CG1 = Function(CG1) + else: + m_DG0 = Function(DG0) + + # Projector object + if projection == "continuous": + projector = ConservativeProjector(rho_DG1, rho_DG0, m_DG1, m_CG1, + subtract_mean=True) + else: + projector = ConservativeProjector(rho_DG1, rho_DG0, m_DG1, m_DG0) + + # Initial conditions + x, y = SpatialCoordinate(mesh) + + rho_expr = Constant(1.0) + 0.5*x*y**2 + m_expr = Constant(2.0) + 0.6*sin(x) + + rho_DG1.interpolate(rho_expr) + m_DG1.interpolate(m_expr) + rho_DG0.project(rho_DG1) + + # Test projection + projector.project() + + tol = 1e-14 + mass_DG1 = assemble(rho_DG1*m_DG1*dx) + + if projection == "continuous": + mass_CG1 = assemble(rho_DG0*m_CG1*dx) + + assert abs(mass_CG1 - mass_DG1) < tol, "continuous projection is not conservative" + + else: + mass_DG0 = assemble(rho_DG0*m_DG0*dx) + + assert abs(mass_DG0 - mass_DG1) < tol, "discontinuous projection is not conservative" From 1481ba863a4833f7c503ec7a4dbaf774380d8309 Mon Sep 17 00:00:00 2001 From: Thomas Bendall Date: Fri, 1 Nov 2024 15:59:21 +0000 Subject: [PATCH 2/4] Allow reference profiles to be updated (#542) Co-authored-by: Dr Jemma Shipton --- gusto/core/io.py | 67 ++++++++++++++++--- gusto/solvers/linear_solvers.py | 26 ++++--- .../semi_implicit_quasi_newton.py | 59 ++++++++++++++-- gusto/timestepping/timestepper.py | 24 +++++-- integration-tests/model/test_checkpointing.py | 18 ++--- 5 files changed, 156 insertions(+), 38 deletions(-) diff --git a/gusto/core/io.py b/gusto/core/io.py index a9a09f1fe..6f8f74712 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -13,14 +13,22 @@ from pyop2.mpi import MPI import numpy as np from gusto.core.logging import logger, update_logfile_location +from collections import namedtuple -__all__ = ["pick_up_mesh", "IO"] +__all__ = ["pick_up_mesh", "IO", "TimeData"] class GustoIOError(IOError): pass +# A named tuple object encapsulating data about timing +TimeData = namedtuple( + 'TimeData', + ['t', 'step', 'initial_steps', 'last_ref_update_time'] +) + + def pick_up_mesh(output, mesh_name): """ Picks up a checkpointed mesh. This must be the first step of any model being @@ -531,7 +539,14 @@ def setup_dump(self, state_fields, t, pick_up=False): # dump initial fields if not pick_up: - self.dump(state_fields, t, step=1) + step = 1 + last_ref_update_time = None + initial_steps = None + time_data = TimeData( + t=t, step=step, initial_steps=initial_steps, + last_ref_update_time=last_ref_update_time + ) + self.dump(state_fields, time_data) def pick_up_from_checkpoint(self, state_fields): """ @@ -541,7 +556,10 @@ def pick_up_from_checkpoint(self, state_fields): state_fields (:class:`StateFields`): the model's field container. Returns: - float: the checkpointed model time. + tuple of (`time_data`, `reference_profiles`): where `time_data` + itself is a named tuple containing the timing data. + The `reference_profiles` are a list of (`field_name`, expr) + pairs describing the reference profile fields. """ # -------------------------------------------------------------------- # @@ -602,6 +620,13 @@ def pick_up_from_checkpoint(self, state_fields): except AttributeError: initial_steps = None + # Try to pick up number last_ref_update_time + # Not compulsory so errors allowed + try: + last_ref_update_time = chk.read_attribute("/", "last_ref_update_time") + except AttributeError: + last_ref_update_time = None + # Finally pick up time and step number t = chk.read_attribute("/", "time") step = chk.read_attribute("/", "step") @@ -632,6 +657,13 @@ def pick_up_from_checkpoint(self, state_fields): else: initial_steps = None + # Try to pick up last reference profile update time + # Not compulsory so errors allowed + if chk.has_attr("/", "last_ref_update_time"): + last_ref_update_time = chk.get_attr("/", "last_ref_update_time") + else: + last_ref_update_time = None + # Finally pick up time t = chk.get_attr("/", "time") step = chk.get_attr("/", "step") @@ -647,9 +679,14 @@ def pick_up_from_checkpoint(self, state_fields): if hasattr(diagnostic_field, "init_field_set"): diagnostic_field.init_field_set = True - return t, reference_profiles, step, initial_steps + time_data = TimeData( + t=t, step=step, initial_steps=initial_steps, + last_ref_update_time=last_ref_update_time + ) + + return time_data, reference_profiles - def dump(self, state_fields, t, step, initial_steps=None): + def dump(self, state_fields, time_data): """ Dumps all of the required model output. @@ -659,12 +696,20 @@ def dump(self, state_fields, t, step, initial_steps=None): Args: state_fields (:class:`StateFields`): the model's field container. - t (float): the simulation's current time. - step (int): the number of time steps. - initial_steps (int, optional): the number of initial time steps - completed by a multi-level time scheme. Defaults to None. + time_data (namedtuple): contains information relating to the time in + the simulation. The tuple is structured as follows: + - t: current time in s + - step: the index of the time step + - initial_steps: number of initial time steps completed by a + multi-level time scheme (could be None) + - last_ref_update_time: the last time in s that the reference + profiles were updated (could be None) """ output = self.output + t = time_data.t + step = time_data.step + initial_steps = time_data.initial_steps + last_ref_update_time = time_data.last_ref_update_time # Diagnostics: # Compute diagnostic fields @@ -688,6 +733,8 @@ def dump(self, state_fields, t, step, initial_steps=None): self.chkpt.write_attribute("/", "step", step) if initial_steps is not None: self.chkpt.write_attribute("/", "initial_steps", initial_steps) + if last_ref_update_time is not None: + self.chkpt.write_attribute("/", "last_ref_update_time", last_ref_update_time) else: with CheckpointFile(self.chkpt_path, 'w') as chk: chk.save_mesh(self.domain.mesh) @@ -697,6 +744,8 @@ def dump(self, state_fields, t, step, initial_steps=None): chk.set_attr("/", "step", step) if initial_steps is not None: chk.set_attr("/", "initial_steps", initial_steps) + if last_ref_update_time is not None: + chk.set_attr("/", "last_ref_update_time", last_ref_update_time) if (next(self.dumpcount) % output.dumpfreq) == 0: if output.dump_nc: diff --git a/gusto/solvers/linear_solvers.py b/gusto/solvers/linear_solvers.py index 8214d2ba2..cdbb6cc2a 100644 --- a/gusto/solvers/linear_solvers.py +++ b/gusto/solvers/linear_solvers.py @@ -27,7 +27,8 @@ from abc import ABCMeta, abstractmethod, abstractproperty -__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver", "ThermalSWSolver", "MoistConvectiveSWSolver"] +__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver", + "ThermalSWSolver", "MoistConvectiveSWSolver"] class TimesteppingSolver(object, metaclass=ABCMeta): @@ -374,6 +375,20 @@ def L_tr(f): python_context = self.hybridized_solver.snes.ksp.pc.getPythonContext() attach_custom_monitor(python_context, logging_ksp_monitor_true_residual) + @timed_function("Gusto:UpdateReferenceProfiles") + def update_reference_profiles(self): + """ + Updates the reference profiles. + """ + + with timed_region("Gusto:HybridProjectRhobar"): + logger.info('Compressible linear solver: rho average solve') + self.rho_avg_solver.solve() + + with timed_region("Gusto:HybridProjectExnerbar"): + logger.info('Compressible linear solver: Exner average solve') + self.exner_avg_solver.solve() + @timed_function("Gusto:LinearSolve") def solve(self, xrhs, dy): """ @@ -387,15 +402,6 @@ def solve(self, xrhs, dy): """ self.xrhs.assign(xrhs) - # TODO: can we avoid computing these each time the solver is called? - with timed_region("Gusto:HybridProjectRhobar"): - logger.info('Compressible linear solver: rho average solve') - self.rho_avg_solver.solve() - - with timed_region("Gusto:HybridProjectExnerbar"): - logger.info('Compressible linear solver: Exner average solve') - self.exner_avg_solver.solve() - # Solve the hybridized system logger.info('Compressible linear solver: hybridized solve') self.hybridized_solver.solve() diff --git a/gusto/timestepping/semi_implicit_quasi_newton.py b/gusto/timestepping/semi_implicit_quasi_newton.py index 1e524a100..ce8758943 100644 --- a/gusto/timestepping/semi_implicit_quasi_newton.py +++ b/gusto/timestepping/semi_implicit_quasi_newton.py @@ -35,7 +35,8 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, diffusion_schemes=None, physics_schemes=None, slow_physics_schemes=None, fast_physics_schemes=None, alpha=Constant(0.5), off_centred_u=False, - num_outer=2, num_inner=2, accelerator=False): + num_outer=2, num_inner=2, accelerator=False, + reference_update_freq=None): """ Args: @@ -84,13 +85,24 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, implicit forcing (pressure gradient and Coriolis) terms, and the linear solve. Defaults to 2. Note that default used by the Met Office's ENDGame and GungHo models is 2. - accelerator (bool, optional): Whether to zero non-wind implicit forcings - for transport terms in order to speed up solver convergence + accelerator (bool, optional): Whether to zero non-wind implicit + forcings for transport terms in order to speed up solver + convergence. Defaults to False. + reference_update_freq (float, optional): frequency with which to + update the reference profile with the n-th time level state + fields. This variable corresponds to time in seconds, and + setting this to zero will update the reference profiles every + time step. Setting it to None turns off the update, and + reference profiles will remain at their initial values. + Defaults to None. """ self.num_outer = num_outer self.num_inner = num_inner self.alpha = alpha + self.accelerator = accelerator + self.reference_update_freq = reference_update_freq + self.to_update_ref_profile = False # default is to not offcentre transporting velocity but if it # is offcentred then use the same value as alpha @@ -188,7 +200,6 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.linear_solver = linear_solver self.forcing = Forcing(equation_set, self.alpha) self.bcs = equation_set.bcs - self.accelerator = accelerator def _apply_bcs(self): """ @@ -252,6 +263,24 @@ def copy_active_tracers(self, x_in, x_out): for name in self.tracers_to_copy: x_out(name).assign(x_in(name)) + def update_reference_profiles(self): + """ + Updates the reference profiles and if required also updates them in the + linear solver. + """ + + if self.reference_update_freq is not None: + if float(self.t) + self.reference_update_freq > self.last_ref_update_time: + self.equation.X_ref.assign(self.x.n(self.field_name)) + self.last_ref_update_time = float(self.t) + if hasattr(self.linear_solver, 'update_reference_profiles'): + self.linear_solver.update_reference_profiles() + + elif self.to_update_ref_profile: + if hasattr(self.linear_solver, 'update_reference_profiles'): + self.linear_solver.update_reference_profiles() + self.to_update_ref_profile = False + def timestep(self): """Defines the timestep""" xn = self.x.n @@ -264,6 +293,10 @@ def timestep(self): xrhs_phys = self.xrhs_phys dy = self.dy + # Update reference profiles -------------------------------------------- + self.update_reference_profiles() + + # Slow physics --------------------------------------------------------- x_after_slow(self.field_name).assign(xn(self.field_name)) if len(self.slow_physics_schemes) > 0: with timed_stage("Slow physics"): @@ -271,6 +304,7 @@ def timestep(self): for _, scheme in self.slow_physics_schemes: scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name)) + # Explict forcing ------------------------------------------------------ with timed_stage("Apply forcing terms"): logger.info('Semi-implicit Quasi Newton: Explicit forcing') # Put explicit forcing into xstar @@ -280,8 +314,10 @@ def timestep(self): # the correct values xp(self.field_name).assign(xstar(self.field_name)) + # OUTER ---------------------------------------------------------------- for outer in range(self.num_outer): + # Transport -------------------------------------------------------- with timed_stage("Transport"): self.io.log_courant(self.fields, 'transporting_velocity', message=f'transporting velocity, outer iteration {outer}') @@ -290,6 +326,7 @@ def timestep(self): # transports a field from xstar and puts result in xp scheme.apply(xp(name), xstar(name)) + # Fast physics ----------------------------------------------------- x_after_fast(self.field_name).assign(xp(self.field_name)) if len(self.fast_physics_schemes) > 0: with timed_stage("Fast physics"): @@ -302,8 +339,7 @@ def timestep(self): for inner in range(self.num_inner): - # TODO: this is where to update the reference state - + # Implicit forcing --------------------------------------------- with timed_stage("Apply forcing terms"): logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}') self.forcing.apply(xp, xnp1, xrhs, "implicit") @@ -314,6 +350,7 @@ def timestep(self): xrhs -= xnp1(self.field_name) xrhs += xrhs_phys + # Linear solve ------------------------------------------------- with timed_stage("Implicit solve"): logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}') self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy @@ -353,10 +390,18 @@ def run(self, t, tmax, pick_up=False): pick_up: (bool): specify whether to pick_up from a previous run """ - if not pick_up: + if not pick_up and self.reference_update_freq is None: assert self.reference_profiles_initialised, \ 'Reference profiles for must be initialised to use Semi-Implicit Timestepper' + if not pick_up and self.reference_update_freq is not None: + # Force reference profiles to be updated on first time step + self.last_ref_update_time = float(t) - float(self.dt) + + elif not pick_up or (pick_up and self.reference_update_freq is None): + # Indicate that linear solver profile needs updating + self.to_update_ref_profile = True + super().run(t, tmax, pick_up=pick_up) diff --git a/gusto/timestepping/timestepper.py b/gusto/timestepping/timestepper.py index bfd45613b..3c619753b 100644 --- a/gusto/timestepping/timestepper.py +++ b/gusto/timestepping/timestepper.py @@ -6,6 +6,7 @@ from pyop2.profiling import timed_stage from gusto.equations import PrognosticEquationSet from gusto.core import TimeLevelFields, StateFields +from gusto.core.io import TimeData from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity from gusto.core.logging import logger from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation @@ -30,6 +31,7 @@ def __init__(self, equation, io): self.dt = self.equation.domain.dt self.t = self.equation.domain.t self.reference_profiles_initialised = False + self.last_ref_update_time = None self.setup_fields() self.setup_scheme() @@ -182,9 +184,14 @@ def run(self, t, tmax, pick_up=False): if pick_up: # Pick up fields, and return other info to be picked up - t, reference_profiles, self.step, initial_timesteps = self.io.pick_up_from_checkpoint(self.fields) - self.set_reference_profiles(reference_profiles) + time_data, reference_profiles = self.io.pick_up_from_checkpoint(self.fields) + t = time_data.t + self.step = time_data.step + initial_timesteps = time_data.initial_steps + last_ref_update_time = time_data.last_ref_update_time + self.set_reference_profiles(reference_profiles, last_ref_update_time) self.set_initial_timesteps(initial_timesteps) + else: self.step = 1 @@ -212,14 +219,19 @@ def run(self, t, tmax, pick_up=False): self.step += 1 with timed_stage("Dump output"): - self.io.dump(self.fields, float(self.t), self.step, self.get_initial_timesteps()) + time_data = TimeData( + t=float(self.t), step=self.step, + initial_steps=self.get_initial_timesteps(), + last_ref_update_time=self.last_ref_update_time + ) + self.io.dump(self.fields, time_data) if self.io.output.checkpoint and self.io.output.checkpoint_method == 'dumbcheckpoint': self.io.chkpt.close() logger.info(f'TIMELOOP complete. t={float(self.t):.5f}, {tmax=:.5f}') - def set_reference_profiles(self, reference_profiles): + def set_reference_profiles(self, reference_profiles, last_ref_update_time=None): """ Initialise the model's reference profiles. @@ -227,6 +239,8 @@ def set_reference_profiles(self, reference_profiles): where 'field_name' is the string giving the name of the reference profile field expr is the :class:`ufl.Expr` whose value is used to set the reference field. + last_ref_update_time (float, optional): the last time that the reference + profiles were updated. Defaults to None. """ for field_name, profile in reference_profiles: if field_name+'_bar' in self.fields: @@ -256,6 +270,8 @@ def set_reference_profiles(self, reference_profiles): # Don't need to do anything else as value in field container has already been set self.reference_profiles_initialised = True + self.last_ref_update_time = last_ref_update_time + class Timestepper(BaseTimestepper): """ diff --git a/integration-tests/model/test_checkpointing.py b/integration-tests/model/test_checkpointing.py index 28549a66f..dc72fb1de 100644 --- a/integration-tests/model/test_checkpointing.py +++ b/integration-tests/model/test_checkpointing.py @@ -11,7 +11,7 @@ import pytest -def set_up_model_objects(mesh, dt, output, stepper_type): +def set_up_model_objects(mesh, dt, output, stepper_type, ref_update_freq): domain = Domain(mesh, dt, "CG", 1) @@ -40,7 +40,8 @@ def set_up_model_objects(mesh, dt, output, stepper_type): # build time stepper stepper = SemiImplicitQuasiNewton(eqns, io, transported_fields, transport_methods, - linear_solver=linear_solver) + linear_solver=linear_solver, + reference_update_freq=ref_update_freq) elif stepper_type == 'multi_level': scheme = AdamsBashforth(domain, order=2) @@ -92,9 +93,10 @@ def initialise_fields(eqns, stepper): stepper.set_reference_profiles([('rho', rho_b), ('theta', theta_b)]) -@pytest.mark.parametrize("stepper_type", ["multi_level", "semi_implicit"]) +@pytest.mark.parametrize("stepper_type, ref_update_freq", [ + ("multi_level", None), ("semi_implicit", None), ("semi_implicit", 0.6)]) @pytest.mark.parametrize("checkpoint_method", ["dumbcheckpoint", "checkpointfile"]) -def test_checkpointing(tmpdir, stepper_type, checkpoint_method): +def test_checkpointing(tmpdir, stepper_type, checkpoint_method, ref_update_freq): mesh_name = 'checkpointing_mesh' @@ -128,8 +130,8 @@ def test_checkpointing(tmpdir, stepper_type, checkpoint_method): chkptfreq=2, ) - stepper_1, eqns_1 = set_up_model_objects(mesh, dt, output_1, stepper_type) - stepper_2, eqns_2 = set_up_model_objects(mesh, dt, output_2, stepper_type) + stepper_1, eqns_1 = set_up_model_objects(mesh, dt, output_1, stepper_type, ref_update_freq) + stepper_2, eqns_2 = set_up_model_objects(mesh, dt, output_2, stepper_type, ref_update_freq) initialise_fields(eqns_1, stepper_1) initialise_fields(eqns_2, stepper_2) @@ -163,7 +165,7 @@ def test_checkpointing(tmpdir, stepper_type, checkpoint_method): if checkpoint_method == 'checkpointfile': mesh = pick_up_mesh(output_3, mesh_name) - stepper_3, _ = set_up_model_objects(mesh, dt, output_3, stepper_type) + stepper_3, _ = set_up_model_objects(mesh, dt, output_3, stepper_type, ref_update_freq) stepper_3.io.pick_up_from_checkpoint(stepper_3.fields) # ------------------------------------------------------------------------ # @@ -192,7 +194,7 @@ def test_checkpointing(tmpdir, stepper_type, checkpoint_method): ) if checkpoint_method == 'checkpointfile': mesh = pick_up_mesh(output_3, mesh_name) - stepper_3, _ = set_up_model_objects(mesh, dt, output_3, stepper_type) + stepper_3, _ = set_up_model_objects(mesh, dt, output_3, stepper_type, ref_update_freq) stepper_3.run(t=2*dt, tmax=4*dt, pick_up=True) # ------------------------------------------------------------------------ # From 950cc239541fe68b5949a1506834b2758e0b125f Mon Sep 17 00:00:00 2001 From: Thomas Bendall Date: Fri, 1 Nov 2024 16:20:18 +0000 Subject: [PATCH 3/4] Predictors (#551) --- examples/shallow_water/williamson_5.py | 1 + .../semi_implicit_quasi_newton.py | 56 +++++++++++++++++-- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/examples/shallow_water/williamson_5.py b/examples/shallow_water/williamson_5.py index 798aed401..e6cae36c7 100644 --- a/examples/shallow_water/williamson_5.py +++ b/examples/shallow_water/williamson_5.py @@ -52,6 +52,7 @@ def williamson_5( # ------------------------------------------------------------------------ # element_order = 1 + # ------------------------------------------------------------------------ # # Set up model objects # ------------------------------------------------------------------------ # diff --git a/gusto/timestepping/semi_implicit_quasi_newton.py b/gusto/timestepping/semi_implicit_quasi_newton.py index ce8758943..2c49304a3 100644 --- a/gusto/timestepping/semi_implicit_quasi_newton.py +++ b/gusto/timestepping/semi_implicit_quasi_newton.py @@ -3,8 +3,10 @@ and GungHo dynamical cores. """ -from firedrake import (Function, Constant, TrialFunctions, DirichletBC, - LinearVariationalProblem, LinearVariationalSolver) +from firedrake import ( + Function, Constant, TrialFunctions, DirichletBC, div, Interpolator, + LinearVariationalProblem, LinearVariationalSolver +) from firedrake.fml import drop, replace_subject from pyop2.profiling import timed_stage from gusto.core import TimeLevelFields, StateFields @@ -36,8 +38,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, slow_physics_schemes=None, fast_physics_schemes=None, alpha=Constant(0.5), off_centred_u=False, num_outer=2, num_inner=2, accelerator=False, - reference_update_freq=None): - + predictor=None, reference_update_freq=None): """ Args: equation_set (:class:`PrognosticEquationSet`): the prognostic @@ -88,6 +89,15 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, accelerator (bool, optional): Whether to zero non-wind implicit forcings for transport terms in order to speed up solver convergence. Defaults to False. + predictor (str, optional): a single string corresponding to the name + of a variable to transport using the divergence predictor. This + pre-multiplies that variable by (1 - beta*dt*div(u)) before the + transport step, and calculates its transport increment from the + transport of this variable. This can improve the stability of + the time stepper at large time steps, when not using an + advective-then-flux formulation. This is only suitable for the + use on the conservative variable (e.g. depth or density). + Defaults to None, in which case no predictor is used. reference_update_freq (float, optional): frequency with which to update the reference profile with the n-th time level state fields. This variable corresponds to time in seconds, and @@ -100,6 +110,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.num_outer = num_outer self.num_inner = num_inner self.alpha = alpha + self.predictor = predictor self.accelerator = accelerator self.reference_update_freq = reference_update_freq self.to_update_ref_profile = False @@ -201,6 +212,14 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.forcing = Forcing(equation_set, self.alpha) self.bcs = equation_set.bcs + if self.predictor is not None: + V_DG = equation_set.domain.spaces('DG') + self.predictor_field_in = Function(V_DG) + div_factor = Constant(1.0) - (Constant(1.0) - self.alpha)*self.dt*div(self.x.n('u')) + self.predictor_interpolator = Interpolator( + self.x.star(predictor)*div_factor, self.predictor_field_in + ) + def _apply_bcs(self): """ Set the zero boundary conditions in the velocity. @@ -263,6 +282,33 @@ def copy_active_tracers(self, x_in, x_out): for name in self.tracers_to_copy: x_out(name).assign(x_in(name)) + def transport_field(self, name, scheme, xstar, xp): + """ + Performs the transport of a field in xstar, placing the result in xp. + + Args: + name (str): the name of the field to be transported. + scheme (:class:`TimeDiscretisation`): the time discretisation used + for the transport. + xstar (:class:`Fields`): the collection of state fields to be + transported. + xp (:class:`Fields`): the collection of state fields resulting from + the transport. + """ + + if name == self.predictor: + # Pre-multiply this variable by (1 - dt*beta*div(u)) + V = xstar(name).function_space() + field_out = Function(V) + self.predictor_interpolator.interpolate() + scheme.apply(field_out, self.predictor_field_in) + + # xp is xstar plus the increment from the transported predictor + xp(name).assign(xstar(name) + field_out - self.predictor_field_in) + else: + # Standard transport + scheme.apply(xp(name), xstar(name)) + def update_reference_profiles(self): """ Updates the reference profiles and if required also updates them in the @@ -324,7 +370,7 @@ def timestep(self): for name, scheme in self.active_transport: logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}') # transports a field from xstar and puts result in xp - scheme.apply(xp(name), xstar(name)) + self.transport_field(name, scheme, xstar, xp) # Fast physics ----------------------------------------------------- x_after_fast(self.field_name).assign(xp(self.field_name)) From 6c5e752a2c776824b5fbc76a000790f0f0929512 Mon Sep 17 00:00:00 2001 From: Thomas Bendall Date: Fri, 1 Nov 2024 16:47:12 +0000 Subject: [PATCH 4/4] Make linear continuity form correct for varying linearisation (#558) --- gusto/equations/common_forms.py | 22 +--------------------- gusto/equations/shallow_water_equations.py | 4 ++-- gusto/solvers/linear_solvers.py | 4 ++-- 3 files changed, 5 insertions(+), 25 deletions(-) diff --git a/gusto/equations/common_forms.py b/gusto/equations/common_forms.py index 91ce01368..8bb606511 100644 --- a/gusto/equations/common_forms.py +++ b/gusto/equations/common_forms.py @@ -14,7 +14,6 @@ "kinetic_energy_form", "advection_equation_circulation_form", "diffusion_form", "diffusion_form_1d", "linear_advection_form", "linear_continuity_form", - "linear_continuity_form_1d", "split_continuity_form", "tracer_conservative_form"] @@ -134,26 +133,7 @@ def linear_continuity_form(test, qbar, ubar): :class:`LabelledForm`: a labelled transport form. """ - L = qbar*test*div(ubar)*dx - form = transporting_velocity(L, ubar) - - return transport(form, TransportEquationType.conservative) - - -def linear_continuity_form_1d(test, qbar, ubar): - """ - The form corresponding to the linearised continuity transport operator. - - Args: - test (:class:`TestFunction`): the test function. - qbar (:class:`ufl.Expr`): the variable to be transported. - ubar (:class:`ufl.Expr`): the transporting velocity. - - Returns: - :class:`LabelledForm`: a labelled transport form. - """ - - L = qbar*test*ubar.dx(0)*dx + L = test*div(qbar*ubar)*dx form = transporting_velocity(L, ubar) return transport(form, TransportEquationType.conservative) diff --git a/gusto/equations/shallow_water_equations.py b/gusto/equations/shallow_water_equations.py index cad6e12cb..854266952 100644 --- a/gusto/equations/shallow_water_equations.py +++ b/gusto/equations/shallow_water_equations.py @@ -9,7 +9,7 @@ advection_form, advection_form_1d, continuity_form, continuity_form_1d, vector_invariant_form, kinetic_energy_form, advection_equation_circulation_form, diffusion_form_1d, - linear_continuity_form, linear_continuity_form_1d + linear_continuity_form ) from gusto.equations.prognostic_equations import PrognosticEquationSet @@ -361,7 +361,7 @@ def __init__(self, domain, parameters, # Transport term needs special linearisation if self.linearisation_map(D_adv.terms[0]): - linear_D_adv = linear_continuity_form_1d(phi, H, u_trial) + linear_D_adv = linear_continuity_form(phi, H, u_trial) # Add linearisation to D_adv D_adv = linearisation(D_adv, linear_D_adv) diff --git a/gusto/solvers/linear_solvers.py b/gusto/solvers/linear_solvers.py index cdbb6cc2a..72dcea7a9 100644 --- a/gusto/solvers/linear_solvers.py +++ b/gusto/solvers/linear_solvers.py @@ -666,7 +666,7 @@ def _setup_solver(self): - beta_u * 0.5 * bbar * div(w*(D-Dbar)) * dx + beta_u * 0.5 * jump((D-Dbar)*w, n) * avg(bbar) * dS + inner(phi, (D - D_in)) * dx - + beta_d * phi * Dbar * div(u) * dx + + beta_d * phi * div(Dbar*u) * dx ) if 'coriolis' in equation.prescribed_fields._field_names: @@ -882,7 +882,7 @@ def _setup_solver(self): inner(w, (u - u_in)) * dx - beta_u * (D - Dbar) * div(w*g) * dx + inner(phi, (D - D_in)) * dx - + beta_d * phi * Dbar * div(u) * dx + + beta_d * phi * div(Dbar*u) * dx ) if 'coriolis' in equation.prescribed_fields._field_names: