From 0678bae35ef00249f1fdf6f8bfddbae0afe11e6b Mon Sep 17 00:00:00 2001 From: Tom Bendall Date: Sat, 21 Sep 2024 13:29:06 +0100 Subject: [PATCH] tidy up and add comment --- examples/shallow_water/williamson_5.py | 7 +- .../semi_implicit_quasi_newton.py | 99 ++++++++----------- 2 files changed, 46 insertions(+), 60 deletions(-) diff --git a/examples/shallow_water/williamson_5.py b/examples/shallow_water/williamson_5.py index 5e46f730a..575eeef08 100644 --- a/examples/shallow_water/williamson_5.py +++ b/examples/shallow_water/williamson_5.py @@ -14,7 +14,7 @@ Domain, IO, OutputParameters, SemiImplicitQuasiNewton, SSPRK3, DGUpwind, TrapeziumRule, ShallowWaterParameters, ShallowWaterEquations, Sum, lonlatr_from_xyz, GeneralIcosahedralSphereMesh, ZonalComponent, - MeridionalComponent, RelativeVorticity, + MeridionalComponent, RelativeVorticity ) williamson_5_defaults = { @@ -72,13 +72,12 @@ def williamson_5( rsq = min_value(R0**2, (lamda - lamda_c)**2 + (phi - phi_c)**2) r = sqrt(rsq) tpexpr = mountain_height * (1 - r/R0) - eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr, - u_transport_option='vector_advection_form') + eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr) # I/O output = OutputParameters( dirname=dirname, dumplist_latlon=['D'], dumpfreq=dumpfreq, - dump_vtus=False, dump_nc=True, dumplist=['D', 'topography'] + dump_vtus=True, dump_nc=False, dumplist=['D', 'topography'] ) diagnostic_fields = [Sum('D', 'topography'), RelativeVorticity(), MeridionalComponent('u'), ZonalComponent('u')] diff --git a/gusto/timestepping/semi_implicit_quasi_newton.py b/gusto/timestepping/semi_implicit_quasi_newton.py index a835d7b66..3a645046c 100644 --- a/gusto/timestepping/semi_implicit_quasi_newton.py +++ b/gusto/timestepping/semi_implicit_quasi_newton.py @@ -3,9 +3,10 @@ and GungHo dynamical cores. """ -from firedrake import (Function, Constant, TrialFunctions, DirichletBC, - LinearVariationalProblem, LinearVariationalSolver, - Interpolator, div) +from firedrake import ( + Function, Constant, TrialFunctions, DirichletBC, div, Interpolator, + LinearVariationalProblem, LinearVariationalSolver +) from firedrake.fml import drop, replace_subject from pyop2.profiling import timed_stage from gusto.core import TimeLevelFields, StateFields @@ -37,7 +38,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, slow_physics_schemes=None, fast_physics_schemes=None, alpha=Constant(0.5), off_centred_u=False, num_outer=2, num_inner=2, accelerator=False, - reference_update_freq=None, predictor=None): + predictor=None): """ Args: @@ -89,21 +90,20 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, accelerator (bool, optional): Whether to zero non-wind implicit forcings for transport terms in order to speed up solver convergence. Defaults to False. - reference_update_freq (float, optional): frequency with which to - update the reference profile with the n-th time level state - fields. This variable corresponds to time in seconds, and - setting this to zero will update the reference profiles every - time step. Setting it to None turns off the update, and - reference profiles will remain at their initial values. - Defaults to None. + predictor (str, optional): a single string corresponding to the name + of a variable to transport using the divergence predictor. This + pre-multiplies that variable by (1 - beta*dt*div(u)) before the + transport step, and calculates its transport increment from the + transport of this variable. This can improve the stability of + the time stepper at large time steps, when not using an + advective-then-flux formulation. This is only suitable for the + use on the conservative variable (e.g. depth or density). + Defaults to None, in which case no predictor is used. """ self.num_outer = num_outer self.num_inner = num_inner self.alpha = alpha - self.accelerator = accelerator - self.reference_update_freq = reference_update_freq - self.to_update_ref_profile = False self.predictor = predictor # default is to not offcentre transporting velocity but if it @@ -202,6 +202,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.linear_solver = linear_solver self.forcing = Forcing(equation_set, self.alpha) self.bcs = equation_set.bcs + self.accelerator = accelerator if self.predictor is not None: V_DG = equation_set.domain.spaces('DG') @@ -272,23 +273,33 @@ def copy_active_tracers(self, x_in, x_out): for name in self.tracers_to_copy: x_out(name).assign(x_in(name)) - def update_reference_profiles(self): + def transport_field(self, name, scheme, xstar, xp): """ - Updates the reference profiles and if required also updates them in the - linear solver. + Performs the transport of a field in xstar, placing the result in xp. + + Args: + name (str): the name of the field to be transported. + scheme (:class:`TimeDiscretisation`): the time discretisation used + for the transport. + xstar (:class:`Fields`): the collection of state fields to be + transported. + xp (:class:`Fields`): the collection of state fields resulting from + the transport. """ - if self.reference_update_freq is not None: - if float(self.t) + self.reference_update_freq > self.last_ref_update_time: - self.equation.X_ref.assign(self.x.n(self.field_name)) - self.last_ref_update_time = float(self.t) - if hasattr(self.linear_solver, 'update_reference_profiles'): - self.linear_solver.update_reference_profiles() + if name == self.predictor: + # Pre-multiply this variable by (1 - dt*beta*div(u)) + V = xstar(name).function_space() + field_in = Function(V) + field_out = Function(V) + self.predictor_interpolator.interpolate() + scheme.apply(field_out, field_in) - elif self.to_update_ref_profile: - if hasattr(self.linear_solver, 'update_reference_profiles'): - self.linear_solver.update_reference_profiles() - self.to_update_ref_profile = False + # xp is xstar plus the increment from the transported predictor + xp(name).assign(xstar(name) + field_out - field_in) + else: + # Standard transport + scheme.apply(xp(name), xstar(name)) def timestep(self): """Defines the timestep""" @@ -302,10 +313,6 @@ def timestep(self): xrhs_phys = self.xrhs_phys dy = self.dy - # Update reference profiles -------------------------------------------- - self.update_reference_profiles() - - # Slow physics --------------------------------------------------------- x_after_slow(self.field_name).assign(xn(self.field_name)) if len(self.slow_physics_schemes) > 0: with timed_stage("Slow physics"): @@ -313,7 +320,6 @@ def timestep(self): for _, scheme in self.slow_physics_schemes: scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name)) - # Explict forcing ------------------------------------------------------ with timed_stage("Apply forcing terms"): logger.info('Semi-implicit Quasi Newton: Explicit forcing') # Put explicit forcing into xstar @@ -323,27 +329,16 @@ def timestep(self): # the correct values xp(self.field_name).assign(xstar(self.field_name)) - # OUTER ---------------------------------------------------------------- for outer in range(self.num_outer): - # Transport -------------------------------------------------------- with timed_stage("Transport"): self.io.log_courant(self.fields, 'transporting_velocity', message=f'transporting velocity, outer iteration {outer}') for name, scheme in self.active_transport: logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}') # transports a field from xstar and puts result in xp - if name == self.predictor: - V = xstar(name).function_space() - field_in = Function(V) - field_out = Function(V) - self.predictor_interpolator.interpolate() - scheme.apply(field_out, field_in) - xp(name).assign(xstar(name) + field_out - field_in) - else: - scheme.apply(xp(name), xstar(name)) - - # Fast physics ----------------------------------------------------- + self.transport_field(name, scheme, xstar, xp) + x_after_fast(self.field_name).assign(xp(self.field_name)) if len(self.fast_physics_schemes) > 0: with timed_stage("Fast physics"): @@ -356,7 +351,8 @@ def timestep(self): for inner in range(self.num_inner): - # Implicit forcing --------------------------------------------- + # TODO: this is where to update the reference state + with timed_stage("Apply forcing terms"): logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}') self.forcing.apply(xp, xnp1, xrhs, "implicit") @@ -367,7 +363,6 @@ def timestep(self): xrhs -= xnp1(self.field_name) xrhs += xrhs_phys - # Linear solve ------------------------------------------------- with timed_stage("Implicit solve"): logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}') self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy @@ -407,18 +402,10 @@ def run(self, t, tmax, pick_up=False): pick_up: (bool): specify whether to pick_up from a previous run """ - if not pick_up and self.reference_update_freq is None: + if not pick_up: assert self.reference_profiles_initialised, \ 'Reference profiles for must be initialised to use Semi-Implicit Timestepper' - if not pick_up and self.reference_update_freq is not None: - # Force reference profiles to be updated on first time step - self.last_ref_update_time = float(t) - float(self.dt) - - elif not pick_up or (pick_up and self.reference_update_freq is None): - # Indicate that linear solver profile needs updating - self.to_update_ref_profile = True - super().run(t, tmax, pick_up=pick_up)