Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reference profiles to be updated #542

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
49 changes: 41 additions & 8 deletions gusto/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,11 @@ def setup_dump(self, state_fields, t, pick_up=False):

# dump initial fields
if not pick_up:
self.dump(state_fields, t, step=1)
step = 1
last_ref_update_time = None
initial_steps = None
time_data = (t, step, initial_steps, last_ref_update_time)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel slightly concerned about having a tuple of things that I might need to remember the order of! Could we use a named tuple?

Something like (completely untested!):

TimeData = namedtuple('TimeData', ['t', 'step', 'initial_steps', 'last_ref_update_time'])
time_data = TimeData(t, step, initial_steps, last_ref_update_time)

Then you can later access these like: t = time_data.t.

self.dump(state_fields, time_data)

def pick_up_from_checkpoint(self, state_fields):
"""
Expand All @@ -541,7 +545,12 @@ def pick_up_from_checkpoint(self, state_fields):
state_fields (:class:`StateFields`): the model's field container.

Returns:
float: the checkpointed model time.
tuple of (`time_data`, `reference_profiles`): where `time_data`
itself is a tuple of numbers relating to the checkpointed time.
This tuple is: (model time, step index,
number of initial steps, last time reference profiles updated).
The `reference_profiles` are a list of (`field_name`, expr)
pairs describing the reference profile fields.
"""

# -------------------------------------------------------------------- #
Expand Down Expand Up @@ -602,6 +611,13 @@ def pick_up_from_checkpoint(self, state_fields):
except AttributeError:
initial_steps = None

# Try to pick up number last_ref_update_time
# Not compulsory so errors allowed
try:
last_ref_update_time = chk.read_attribute("/", "last_ref_update_time")
except AttributeError:
last_ref_update_time = None

# Finally pick up time and step number
t = chk.read_attribute("/", "time")
step = chk.read_attribute("/", "step")
Expand Down Expand Up @@ -632,6 +648,13 @@ def pick_up_from_checkpoint(self, state_fields):
else:
initial_steps = None

# Try to pick up last reference profile update time
# Not compulsory so errors allowed
if chk.has_attr("/", "last_ref_update_time"):
last_ref_update_time = chk.get_attr("/", "last_ref_update_time")
else:
last_ref_update_time = None

# Finally pick up time
t = chk.get_attr("/", "time")
step = chk.get_attr("/", "step")
Expand All @@ -647,9 +670,10 @@ def pick_up_from_checkpoint(self, state_fields):
if hasattr(diagnostic_field, "init_field_set"):
diagnostic_field.init_field_set = True

return t, reference_profiles, step, initial_steps
time_data = (t, step, initial_steps, last_ref_update_time)
return time_data, reference_profiles

def dump(self, state_fields, t, step, initial_steps=None):
def dump(self, state_fields, time_data):
"""
Dumps all of the required model output.

Expand All @@ -659,12 +683,17 @@ def dump(self, state_fields, t, step, initial_steps=None):

Args:
state_fields (:class:`StateFields`): the model's field container.
t (float): the simulation's current time.
step (int): the number of time steps.
initial_steps (int, optional): the number of initial time steps
completed by a multi-level time scheme. Defaults to None.
time_data (tuple): contains information relating to the time in
the simulation. The tuple is structured as follows:
- t: current time in s
- step: the index of the time step
- initial_steps: number of initial time steps completed by a
multi-level time scheme (could be None)
- last_ref_update_time: the last time in s that the reference
profiles were updated (could be None)
"""
output = self.output
t, step, initial_steps, last_ref_update_time = time_data

# Diagnostics:
# Compute diagnostic fields
Expand All @@ -688,6 +717,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
self.chkpt.write_attribute("/", "step", step)
if initial_steps is not None:
self.chkpt.write_attribute("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
self.chkpt.write_attribute("/", "last_ref_update_time", last_ref_update_time)
else:
with CheckpointFile(self.chkpt_path, 'w') as chk:
chk.save_mesh(self.domain.mesh)
Expand All @@ -697,6 +728,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
chk.set_attr("/", "step", step)
if initial_steps is not None:
chk.set_attr("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
chk.set_attr("/", "last_ref_update_time", last_ref_update_time)

if (next(self.dumpcount) % output.dumpfreq) == 0:
if output.dump_nc:
Expand Down
26 changes: 16 additions & 10 deletions gusto/solvers/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from abc import ABCMeta, abstractmethod, abstractproperty


__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver", "ThermalSWSolver", "MoistConvectiveSWSolver"]
__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver",
"ThermalSWSolver", "MoistConvectiveSWSolver"]


class TimesteppingSolver(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -374,6 +375,20 @@ def L_tr(f):
python_context = self.hybridized_solver.snes.ksp.pc.getPythonContext()
attach_custom_monitor(python_context, logging_ksp_monitor_true_residual)

@timed_function("Gusto:UpdateReferenceProfiles")
def update_reference_profiles(self):
"""
Updates the reference profiles.
"""

with timed_region("Gusto:HybridProjectRhobar"):
logger.info('Compressible linear solver: rho average solve')
self.rho_avg_solver.solve()

with timed_region("Gusto:HybridProjectExnerbar"):
logger.info('Compressible linear solver: Exner average solve')
self.exner_avg_solver.solve()

@timed_function("Gusto:LinearSolve")
def solve(self, xrhs, dy):
"""
Expand All @@ -387,15 +402,6 @@ def solve(self, xrhs, dy):
"""
self.xrhs.assign(xrhs)

# TODO: can we avoid computing these each time the solver is called?
with timed_region("Gusto:HybridProjectRhobar"):
logger.info('Compressible linear solver: rho average solve')
self.rho_avg_solver.solve()

with timed_region("Gusto:HybridProjectExnerbar"):
logger.info('Compressible linear solver: Exner average solve')
self.exner_avg_solver.solve()

# Solve the hybridized system
logger.info('Compressible linear solver: hybridized solve')
self.hybridized_solver.solve()
Expand Down
59 changes: 52 additions & 7 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
diffusion_schemes=None, physics_schemes=None,
slow_physics_schemes=None, fast_physics_schemes=None,
alpha=Constant(0.5), off_centred_u=False,
num_outer=2, num_inner=2, accelerator=False):
num_outer=2, num_inner=2, accelerator=False,
reference_update_freq=None):

"""
Args:
Expand Down Expand Up @@ -84,13 +85,24 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
implicit forcing (pressure gradient and Coriolis) terms, and the
linear solve. Defaults to 2. Note that default used by the Met
Office's ENDGame and GungHo models is 2.
accelerator (bool, optional): Whether to zero non-wind implicit forcings
for transport terms in order to speed up solver convergence
accelerator (bool, optional): Whether to zero non-wind implicit
forcings for transport terms in order to speed up solver
convergence. Defaults to False.
reference_update_freq (float, optional): frequency with which to
update the reference profile with the n-th time level state
fields. This variable corresponds to time in seconds, and
setting this to zero will update the reference profiles every
time step. Setting it to None turns off the update, and
reference profiles will remain at their initial values.
Defaults to None.
"""

self.num_outer = num_outer
self.num_inner = num_inner
self.alpha = alpha
self.accelerator = accelerator
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False

# default is to not offcentre transporting velocity but if it
# is offcentred then use the same value as alpha
Expand Down Expand Up @@ -188,7 +200,6 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.linear_solver = linear_solver
self.forcing = Forcing(equation_set, self.alpha)
self.bcs = equation_set.bcs
self.accelerator = accelerator

def _apply_bcs(self):
"""
Expand Down Expand Up @@ -252,6 +263,24 @@ def copy_active_tracers(self, x_in, x_out):
for name in self.tracers_to_copy:
x_out(name).assign(x_in(name))

def update_reference_profiles(self):
"""
Updates the reference profiles and if required also updates them in the
linear solver.
"""

if self.reference_update_freq is not None:
if float(self.t) + self.reference_update_freq > self.last_ref_update_time:
self.equation.X_ref.assign(self.x.n(self.field_name))
self.last_ref_update_time = float(self.t)
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()

elif self.to_update_ref_profile:
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()
self.to_update_ref_profile = False

def timestep(self):
"""Defines the timestep"""
xn = self.x.n
Expand All @@ -264,13 +293,18 @@ def timestep(self):
xrhs_phys = self.xrhs_phys
dy = self.dy

# Update reference profiles --------------------------------------------
self.update_reference_profiles()

# Slow physics ---------------------------------------------------------
x_after_slow(self.field_name).assign(xn(self.field_name))
if len(self.slow_physics_schemes) > 0:
with timed_stage("Slow physics"):
logger.info('Semi-implicit Quasi Newton: Slow physics')
for _, scheme in self.slow_physics_schemes:
scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name))

# Explict forcing ------------------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info('Semi-implicit Quasi Newton: Explicit forcing')
# Put explicit forcing into xstar
Expand All @@ -280,8 +314,10 @@ def timestep(self):
# the correct values
xp(self.field_name).assign(xstar(self.field_name))

# OUTER ----------------------------------------------------------------
for outer in range(self.num_outer):

# Transport --------------------------------------------------------
with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
Expand All @@ -290,6 +326,7 @@ def timestep(self):
# transports a field from xstar and puts result in xp
scheme.apply(xp(name), xstar(name))

# Fast physics -----------------------------------------------------
x_after_fast(self.field_name).assign(xp(self.field_name))
if len(self.fast_physics_schemes) > 0:
with timed_stage("Fast physics"):
Expand All @@ -302,8 +339,7 @@ def timestep(self):

for inner in range(self.num_inner):

# TODO: this is where to update the reference state
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this comment just wrong?


# Implicit forcing ---------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}')
self.forcing.apply(xp, xnp1, xrhs, "implicit")
Expand All @@ -314,6 +350,7 @@ def timestep(self):
xrhs -= xnp1(self.field_name)
xrhs += xrhs_phys

# Linear solve -------------------------------------------------
with timed_stage("Implicit solve"):
logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}')
self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy
Expand Down Expand Up @@ -353,10 +390,18 @@ def run(self, t, tmax, pick_up=False):
pick_up: (bool): specify whether to pick_up from a previous run
"""

if not pick_up:
if not pick_up and self.reference_update_freq is None:
assert self.reference_profiles_initialised, \
'Reference profiles for must be initialised to use Semi-Implicit Timestepper'

if not pick_up and self.reference_update_freq is not None:
# Force reference profiles to be updated on first time step
self.last_ref_update_time = float(t) - float(self.dt)

elif not pick_up or (pick_up and self.reference_update_freq is None):
# Indicate that linear solver profile needs updating
self.to_update_ref_profile = True

super().run(t, tmax, pick_up=pick_up)


Expand Down
19 changes: 15 additions & 4 deletions gusto/timestepping/timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, equation, io):
self.dt = self.equation.domain.dt
self.t = self.equation.domain.t
self.reference_profiles_initialised = False
self.last_ref_update_time = None

self.setup_fields()
self.setup_scheme()
Expand Down Expand Up @@ -163,9 +164,11 @@ def run(self, t, tmax, pick_up=False):

if pick_up:
# Pick up fields, and return other info to be picked up
t, reference_profiles, self.step, initial_timesteps = self.io.pick_up_from_checkpoint(self.fields)
self.set_reference_profiles(reference_profiles)
time_data, reference_profiles = self.io.pick_up_from_checkpoint(self.fields)
t, self.step, initial_timesteps, last_ref_update_time = time_data
self.set_reference_profiles(reference_profiles, last_ref_update_time)
self.set_initial_timesteps(initial_timesteps)

else:
self.step = 1

Expand Down Expand Up @@ -193,21 +196,27 @@ def run(self, t, tmax, pick_up=False):
self.step += 1

with timed_stage("Dump output"):
self.io.dump(self.fields, float(self.t), self.step, self.get_initial_timesteps())
time_data = (
float(self.t), self.step,
self.get_initial_timesteps(), self.last_ref_update_time
)
self.io.dump(self.fields, time_data)

if self.io.output.checkpoint and self.io.output.checkpoint_method == 'dumbcheckpoint':
self.io.chkpt.close()

logger.info(f'TIMELOOP complete. t={float(self.t):.5f}, {tmax=:.5f}')

def set_reference_profiles(self, reference_profiles):
def set_reference_profiles(self, reference_profiles, last_ref_update_time=None):
"""
Initialise the model's reference profiles.

reference_profiles (list): an iterable of pairs: (field_name, expr),
where 'field_name' is the string giving the name of the reference
profile field expr is the :class:`ufl.Expr` whose value is used to
set the reference field.
last_ref_update_time (float, optional): the last time that the reference
profiles were updated. Defaults to None.
"""
for field_name, profile in reference_profiles:
if field_name+'_bar' in self.fields:
Expand Down Expand Up @@ -237,6 +246,8 @@ def set_reference_profiles(self, reference_profiles):
# Don't need to do anything else as value in field container has already been set
self.reference_profiles_initialised = True

self.last_ref_update_time = last_ref_update_time


class Timestepper(BaseTimestepper):
"""
Expand Down
Loading
Loading