Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reference profiles to be updated #542

Merged
merged 9 commits into from
Nov 1, 2024
67 changes: 58 additions & 9 deletions gusto/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@
from pyop2.mpi import MPI
import numpy as np
from gusto.core.logging import logger, update_logfile_location
from collections import namedtuple

__all__ = ["pick_up_mesh", "IO"]
__all__ = ["pick_up_mesh", "IO", "TimeData"]


class GustoIOError(IOError):
pass


# A named tuple object encapsulating data about timing
TimeData = namedtuple(
'TimeData',
['t', 'step', 'initial_steps', 'last_ref_update_time']
)


def pick_up_mesh(output, mesh_name):
"""
Picks up a checkpointed mesh. This must be the first step of any model being
Expand Down Expand Up @@ -531,7 +539,14 @@ def setup_dump(self, state_fields, t, pick_up=False):

# dump initial fields
if not pick_up:
self.dump(state_fields, t, step=1)
step = 1
last_ref_update_time = None
initial_steps = None
time_data = TimeData(
t=t, step=step, initial_steps=initial_steps,
last_ref_update_time=last_ref_update_time
)
self.dump(state_fields, time_data)

def pick_up_from_checkpoint(self, state_fields):
"""
Expand All @@ -541,7 +556,10 @@ def pick_up_from_checkpoint(self, state_fields):
state_fields (:class:`StateFields`): the model's field container.

Returns:
float: the checkpointed model time.
tuple of (`time_data`, `reference_profiles`): where `time_data`
itself is a named tuple containing the timing data.
The `reference_profiles` are a list of (`field_name`, expr)
pairs describing the reference profile fields.
"""

# -------------------------------------------------------------------- #
Expand Down Expand Up @@ -602,6 +620,13 @@ def pick_up_from_checkpoint(self, state_fields):
except AttributeError:
initial_steps = None

# Try to pick up number last_ref_update_time
# Not compulsory so errors allowed
try:
last_ref_update_time = chk.read_attribute("/", "last_ref_update_time")
except AttributeError:
last_ref_update_time = None

# Finally pick up time and step number
t = chk.read_attribute("/", "time")
step = chk.read_attribute("/", "step")
Expand Down Expand Up @@ -632,6 +657,13 @@ def pick_up_from_checkpoint(self, state_fields):
else:
initial_steps = None

# Try to pick up last reference profile update time
# Not compulsory so errors allowed
if chk.has_attr("/", "last_ref_update_time"):
last_ref_update_time = chk.get_attr("/", "last_ref_update_time")
else:
last_ref_update_time = None

# Finally pick up time
t = chk.get_attr("/", "time")
step = chk.get_attr("/", "step")
Expand All @@ -647,9 +679,14 @@ def pick_up_from_checkpoint(self, state_fields):
if hasattr(diagnostic_field, "init_field_set"):
diagnostic_field.init_field_set = True

return t, reference_profiles, step, initial_steps
time_data = TimeData(
t=t, step=step, initial_steps=initial_steps,
last_ref_update_time=last_ref_update_time
)

return time_data, reference_profiles

def dump(self, state_fields, t, step, initial_steps=None):
def dump(self, state_fields, time_data):
"""
Dumps all of the required model output.

Expand All @@ -659,12 +696,20 @@ def dump(self, state_fields, t, step, initial_steps=None):

Args:
state_fields (:class:`StateFields`): the model's field container.
t (float): the simulation's current time.
step (int): the number of time steps.
initial_steps (int, optional): the number of initial time steps
completed by a multi-level time scheme. Defaults to None.
time_data (namedtuple): contains information relating to the time in
the simulation. The tuple is structured as follows:
- t: current time in s
- step: the index of the time step
- initial_steps: number of initial time steps completed by a
multi-level time scheme (could be None)
- last_ref_update_time: the last time in s that the reference
profiles were updated (could be None)
"""
output = self.output
t = time_data.t
step = time_data.step
initial_steps = time_data.initial_steps
last_ref_update_time = time_data.last_ref_update_time

# Diagnostics:
# Compute diagnostic fields
Expand All @@ -688,6 +733,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
self.chkpt.write_attribute("/", "step", step)
if initial_steps is not None:
self.chkpt.write_attribute("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
self.chkpt.write_attribute("/", "last_ref_update_time", last_ref_update_time)
else:
with CheckpointFile(self.chkpt_path, 'w') as chk:
chk.save_mesh(self.domain.mesh)
Expand All @@ -697,6 +744,8 @@ def dump(self, state_fields, t, step, initial_steps=None):
chk.set_attr("/", "step", step)
if initial_steps is not None:
chk.set_attr("/", "initial_steps", initial_steps)
if last_ref_update_time is not None:
chk.set_attr("/", "last_ref_update_time", last_ref_update_time)

if (next(self.dumpcount) % output.dumpfreq) == 0:
if output.dump_nc:
Expand Down
26 changes: 16 additions & 10 deletions gusto/solvers/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from abc import ABCMeta, abstractmethod, abstractproperty


__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver", "ThermalSWSolver", "MoistConvectiveSWSolver"]
__all__ = ["BoussinesqSolver", "LinearTimesteppingSolver", "CompressibleSolver",
"ThermalSWSolver", "MoistConvectiveSWSolver"]


class TimesteppingSolver(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -374,6 +375,20 @@ def L_tr(f):
python_context = self.hybridized_solver.snes.ksp.pc.getPythonContext()
attach_custom_monitor(python_context, logging_ksp_monitor_true_residual)

@timed_function("Gusto:UpdateReferenceProfiles")
def update_reference_profiles(self):
"""
Updates the reference profiles.
"""

with timed_region("Gusto:HybridProjectRhobar"):
logger.info('Compressible linear solver: rho average solve')
self.rho_avg_solver.solve()

with timed_region("Gusto:HybridProjectExnerbar"):
logger.info('Compressible linear solver: Exner average solve')
self.exner_avg_solver.solve()

@timed_function("Gusto:LinearSolve")
def solve(self, xrhs, dy):
"""
Expand All @@ -387,15 +402,6 @@ def solve(self, xrhs, dy):
"""
self.xrhs.assign(xrhs)

# TODO: can we avoid computing these each time the solver is called?
with timed_region("Gusto:HybridProjectRhobar"):
logger.info('Compressible linear solver: rho average solve')
self.rho_avg_solver.solve()

with timed_region("Gusto:HybridProjectExnerbar"):
logger.info('Compressible linear solver: Exner average solve')
self.exner_avg_solver.solve()

# Solve the hybridized system
logger.info('Compressible linear solver: hybridized solve')
self.hybridized_solver.solve()
Expand Down
59 changes: 52 additions & 7 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
diffusion_schemes=None, physics_schemes=None,
slow_physics_schemes=None, fast_physics_schemes=None,
alpha=Constant(0.5), off_centred_u=False,
num_outer=2, num_inner=2, accelerator=False):
num_outer=2, num_inner=2, accelerator=False,
reference_update_freq=None):

"""
Args:
Expand Down Expand Up @@ -84,13 +85,24 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
implicit forcing (pressure gradient and Coriolis) terms, and the
linear solve. Defaults to 2. Note that default used by the Met
Office's ENDGame and GungHo models is 2.
accelerator (bool, optional): Whether to zero non-wind implicit forcings
for transport terms in order to speed up solver convergence
accelerator (bool, optional): Whether to zero non-wind implicit
forcings for transport terms in order to speed up solver
convergence. Defaults to False.
reference_update_freq (float, optional): frequency with which to
update the reference profile with the n-th time level state
fields. This variable corresponds to time in seconds, and
setting this to zero will update the reference profiles every
time step. Setting it to None turns off the update, and
reference profiles will remain at their initial values.
Defaults to None.
"""

self.num_outer = num_outer
self.num_inner = num_inner
self.alpha = alpha
self.accelerator = accelerator
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False

# default is to not offcentre transporting velocity but if it
# is offcentred then use the same value as alpha
Expand Down Expand Up @@ -188,7 +200,6 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.linear_solver = linear_solver
self.forcing = Forcing(equation_set, self.alpha)
self.bcs = equation_set.bcs
self.accelerator = accelerator

def _apply_bcs(self):
"""
Expand Down Expand Up @@ -252,6 +263,24 @@ def copy_active_tracers(self, x_in, x_out):
for name in self.tracers_to_copy:
x_out(name).assign(x_in(name))

def update_reference_profiles(self):
"""
Updates the reference profiles and if required also updates them in the
linear solver.
"""

if self.reference_update_freq is not None:
if float(self.t) + self.reference_update_freq > self.last_ref_update_time:
self.equation.X_ref.assign(self.x.n(self.field_name))
self.last_ref_update_time = float(self.t)
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()

elif self.to_update_ref_profile:
if hasattr(self.linear_solver, 'update_reference_profiles'):
self.linear_solver.update_reference_profiles()
self.to_update_ref_profile = False

def timestep(self):
"""Defines the timestep"""
xn = self.x.n
Expand All @@ -264,13 +293,18 @@ def timestep(self):
xrhs_phys = self.xrhs_phys
dy = self.dy

# Update reference profiles --------------------------------------------
self.update_reference_profiles()

# Slow physics ---------------------------------------------------------
x_after_slow(self.field_name).assign(xn(self.field_name))
if len(self.slow_physics_schemes) > 0:
with timed_stage("Slow physics"):
logger.info('Semi-implicit Quasi Newton: Slow physics')
for _, scheme in self.slow_physics_schemes:
scheme.apply(x_after_slow(scheme.field_name), x_after_slow(scheme.field_name))

# Explict forcing ------------------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info('Semi-implicit Quasi Newton: Explicit forcing')
# Put explicit forcing into xstar
Expand All @@ -280,8 +314,10 @@ def timestep(self):
# the correct values
xp(self.field_name).assign(xstar(self.field_name))

# OUTER ----------------------------------------------------------------
for outer in range(self.num_outer):

# Transport --------------------------------------------------------
with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
Expand All @@ -290,6 +326,7 @@ def timestep(self):
# transports a field from xstar and puts result in xp
scheme.apply(xp(name), xstar(name))

# Fast physics -----------------------------------------------------
x_after_fast(self.field_name).assign(xp(self.field_name))
if len(self.fast_physics_schemes) > 0:
with timed_stage("Fast physics"):
Expand All @@ -302,8 +339,7 @@ def timestep(self):

for inner in range(self.num_inner):

# TODO: this is where to update the reference state
jshipton marked this conversation as resolved.
Show resolved Hide resolved

# Implicit forcing ---------------------------------------------
with timed_stage("Apply forcing terms"):
logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}')
self.forcing.apply(xp, xnp1, xrhs, "implicit")
Expand All @@ -314,6 +350,7 @@ def timestep(self):
xrhs -= xnp1(self.field_name)
xrhs += xrhs_phys

# Linear solve -------------------------------------------------
with timed_stage("Implicit solve"):
logger.info(f'Semi-implicit Quasi Newton: Mixed solve {(outer, inner)}')
self.linear_solver.solve(xrhs, dy) # solves linear system and places result in dy
Expand Down Expand Up @@ -353,10 +390,18 @@ def run(self, t, tmax, pick_up=False):
pick_up: (bool): specify whether to pick_up from a previous run
"""

if not pick_up:
if not pick_up and self.reference_update_freq is None:
assert self.reference_profiles_initialised, \
'Reference profiles for must be initialised to use Semi-Implicit Timestepper'

if not pick_up and self.reference_update_freq is not None:
# Force reference profiles to be updated on first time step
self.last_ref_update_time = float(t) - float(self.dt)

elif not pick_up or (pick_up and self.reference_update_freq is None):
# Indicate that linear solver profile needs updating
self.to_update_ref_profile = True

super().run(t, tmax, pick_up=pick_up)


Expand Down
Loading
Loading