Skip to content

Commit

Permalink
tidy up and add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
tommbendall committed Sep 21, 2024
1 parent f06e8a8 commit 0678bae
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 60 deletions.
7 changes: 3 additions & 4 deletions examples/shallow_water/williamson_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Domain, IO, OutputParameters, SemiImplicitQuasiNewton, SSPRK3, DGUpwind,
TrapeziumRule, ShallowWaterParameters, ShallowWaterEquations, Sum,
lonlatr_from_xyz, GeneralIcosahedralSphereMesh, ZonalComponent,
MeridionalComponent, RelativeVorticity,
MeridionalComponent, RelativeVorticity
)

williamson_5_defaults = {
Expand Down Expand Up @@ -72,13 +72,12 @@ def williamson_5(
rsq = min_value(R0**2, (lamda - lamda_c)**2 + (phi - phi_c)**2)
r = sqrt(rsq)
tpexpr = mountain_height * (1 - r/R0)
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr,
u_transport_option='vector_advection_form')
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr, bexpr=tpexpr)

# I/O
output = OutputParameters(
dirname=dirname, dumplist_latlon=['D'], dumpfreq=dumpfreq,
dump_vtus=False, dump_nc=True, dumplist=['D', 'topography']
dump_vtus=True, dump_nc=False, dumplist=['D', 'topography']
)
diagnostic_fields = [Sum('D', 'topography'), RelativeVorticity(),
MeridionalComponent('u'), ZonalComponent('u')]
Expand Down
99 changes: 43 additions & 56 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
and GungHo dynamical cores.
"""

from firedrake import (Function, Constant, TrialFunctions, DirichletBC,
LinearVariationalProblem, LinearVariationalSolver,
Interpolator, div)
from firedrake import (
Function, Constant, TrialFunctions, DirichletBC, div, Interpolator,
LinearVariationalProblem, LinearVariationalSolver
)
from firedrake.fml import drop, replace_subject
from pyop2.profiling import timed_stage
from gusto.core import TimeLevelFields, StateFields
Expand Down Expand Up @@ -37,7 +38,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
slow_physics_schemes=None, fast_physics_schemes=None,
alpha=Constant(0.5), off_centred_u=False,
num_outer=2, num_inner=2, accelerator=False,
reference_update_freq=None, predictor=None):
predictor=None):

"""
Args:
Expand Down Expand Up @@ -89,21 +90,20 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
accelerator (bool, optional): Whether to zero non-wind implicit
forcings for transport terms in order to speed up solver
convergence. Defaults to False.
reference_update_freq (float, optional): frequency with which to
update the reference profile with the n-th time level state
fields. This variable corresponds to time in seconds, and
setting this to zero will update the reference profiles every
time step. Setting it to None turns off the update, and
reference profiles will remain at their initial values.
Defaults to None.
predictor (str, optional): a single string corresponding to the name
of a variable to transport using the divergence predictor. This
pre-multiplies that variable by (1 - beta*dt*div(u)) before the
transport step, and calculates its transport increment from the
transport of this variable. This can improve the stability of
the time stepper at large time steps, when not using an
advective-then-flux formulation. This is only suitable for the
use on the conservative variable (e.g. depth or density).
Defaults to None, in which case no predictor is used.
"""

self.num_outer = num_outer
self.num_inner = num_inner
self.alpha = alpha
self.accelerator = accelerator
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False
self.predictor = predictor

# default is to not offcentre transporting velocity but if it
Expand Down Expand Up @@ -202,6 +202,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.linear_solver = linear_solver
self.forcing = Forcing(equation_set, self.alpha)
self.bcs = equation_set.bcs
self.accelerator = accelerator

if self.predictor is not None:
V_DG = equation_set.domain.spaces('DG')
Expand Down Expand Up @@ -272,23 +273,33 @@ def copy_active_tracers(self, x_in, x_out):
for name in self.tracers_to_copy:
x_out(name).assign(x_in(name))

def update_reference_profiles(self):
def transport_field(self, name, scheme, xstar, xp):
"""
Updates the reference profiles and if required also updates them in the
linear solver.
Performs the transport of a field in xstar, placing the result in xp.
Args:
name (str): the name of the field to be transported.
scheme (:class:`TimeDiscretisation`): the time discretisation used
for the transport.
xstar (:class:`Fields`): the collection of state fields to be
transported.
xp (:class:`Fields`): the collection of state fields resulting from
the transport.
"""

if self.reference_update_freq is not None:
if float(self.t) + self.reference_update_freq > self.last_ref_update_time:
self.equation.X_ref.assign(self.x.n(self.field_name))
self.last_ref_update_time = float(self.t)
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()
if name == self.predictor:
# Pre-multiply this variable by (1 - dt*beta*div(u))
V = xstar(name).function_space()
field_in = Function(V)
field_out = Function(V)
self.predictor_interpolator.interpolate()
scheme.apply(field_out, field_in)

elif self.to_update_ref_profile:
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()
self.to_update_ref_profile = False
# xp is xstar plus the increment from the transported predictor
xp(name).assign(xstar(name) + field_out - field_in)
else:
# Standard transport
scheme.apply(xp(name), xstar(name))

def timestep(self):
"""Defines the timestep"""
Expand All @@ -302,18 +313,13 @@ def timestep(self):
xrhs_phys = self.xrhs_phys
dy = self.dy

# Update reference profiles --------------------------------------------
self.update_reference_profiles()

# Slow physics ---------------------------------------------------------
x_after_slow(self.field_name).assign(xn(self.field_name))
if len(self.slow_physics_schemes) > 0:
with timed_stage("Slow physics"):
logger.info('Semi-implicit Quasi Newton: Slow physics')
for _, scheme in self.slow_physics_schemes:
scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name))

# Explict forcing ------------------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info('Semi-implicit Quasi Newton: Explicit forcing')
# Put explicit forcing into xstar
Expand All @@ -323,27 +329,16 @@ def timestep(self):
# the correct values
xp(self.field_name).assign(xstar(self.field_name))

# OUTER ----------------------------------------------------------------
for outer in range(self.num_outer):

# Transport --------------------------------------------------------
with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
for name, scheme in self.active_transport:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a field from xstar and puts result in xp
if name == self.predictor:
V = xstar(name).function_space()
field_in = Function(V)
field_out = Function(V)
self.predictor_interpolator.interpolate()
scheme.apply(field_out, field_in)
xp(name).assign(xstar(name) + field_out - field_in)
else:
scheme.apply(xp(name), xstar(name))

# Fast physics -----------------------------------------------------
self.transport_field(name, scheme, xstar, xp)

x_after_fast(self.field_name).assign(xp(self.field_name))
if len(self.fast_physics_schemes) > 0:
with timed_stage("Fast physics"):
Expand All @@ -356,7 +351,8 @@ def timestep(self):

for inner in range(self.num_inner):

# Implicit forcing ---------------------------------------------
# TODO: this is where to update the reference state

with timed_stage("Apply forcing terms"):
logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}')
self.forcing.apply(xp, xnp1, xrhs, "implicit")
Expand All @@ -367,7 +363,6 @@ def timestep(self):
xrhs -= xnp1(self.field_name)
xrhs += xrhs_phys

# Linear solve -------------------------------------------------
with timed_stage("Implicit solve"):
logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}')
self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy
Expand Down Expand Up @@ -407,18 +402,10 @@ def run(self, t, tmax, pick_up=False):
pick_up: (bool): specify whether to pick_up from a previous run
"""

if not pick_up and self.reference_update_freq is None:
if not pick_up:
assert self.reference_profiles_initialised, \
'Reference profiles for must be initialised to use Semi-Implicit Timestepper'

if not pick_up and self.reference_update_freq is not None:
# Force reference profiles to be updated on first time step
self.last_ref_update_time = float(t) - float(self.dt)

elif not pick_up or (pick_up and self.reference_update_freq is None):
# Indicate that linear solver profile needs updating
self.to_update_ref_profile = True

super().run(t, tmax, pick_up=pick_up)


Expand Down

0 comments on commit 0678bae

Please sign in to comment.