Skip to content

Commit

Permalink
PR #536: from firedrakeproject/solver_improvements
Browse files Browse the repository at this point in the history
Solver Convergence Improvements
  • Loading branch information
tommbendall authored Aug 16, 2024
2 parents bd2ef2b + a88e3cb commit e4cbef6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 41 deletions.
104 changes: 64 additions & 40 deletions gusto/solvers/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
class TimesteppingSolver(object, metaclass=ABCMeta):
"""Base class for timestepping linear solvers for Gusto."""

def __init__(self, equations, alpha=0.5, solver_parameters=None,
overwrite_solver_parameters=False):
def __init__(self, equations, alpha=0.5, tau_values=None,
solver_parameters=None, overwrite_solver_parameters=False):
"""
Args:
equations (:class:`PrognosticEquation`): the model's equation.
alpha (float, optional): the semi-implicit off-centring factor.
Defaults to 0.5. A value of 1 is fully-implicit.
tau_values (dict, optional): contains the semi-implicit relaxation
parameters. Defaults to None, in which case the value of alpha is used.
solver_parameters (dict, optional): contains the options to be
passed to the underlying :class:`LinearVariationalSolver`.
Defaults to None.
Expand All @@ -51,6 +53,7 @@ def __init__(self, equations, alpha=0.5, solver_parameters=None,
self.equations = equations
self.dt = equations.domain.dt
self.alpha = alpha
self.tau_values = tau_values if tau_values is not None else {}

if solver_parameters is not None:
if not overwrite_solver_parameters:
Expand Down Expand Up @@ -135,14 +138,16 @@ class CompressibleSolver(TimesteppingSolver):
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'}}}

def __init__(self, equations, alpha=0.5,
def __init__(self, equations, alpha=0.5, tau_values=None,
quadrature_degree=None, solver_parameters=None,
overwrite_solver_parameters=False):
"""
Args:
equations (:class:`PrognosticEquation`): the model's equation.
alpha (float, optional): the semi-implicit off-centring factor.
Defaults to 0.5. A value of 1 is fully-implicit.
tau_values (dict, optional): contains the semi-implicit relaxation
parameters. Defaults to None, in which case the value of alpha is used.
quadrature_degree (tuple, optional): a tuple (q_h, q_v) where q_h is
the required quadrature degree in the horizontal direction and
q_v is that in the vertical direction. Defaults to None.
Expand All @@ -164,24 +169,31 @@ def __init__(self, equations, alpha=0.5,
logger.warning("default quadrature degree most likely not sufficient for this degree element")
self.quadrature_degree = (5, 5)

super().__init__(equations, alpha, solver_parameters,
super().__init__(equations, alpha, tau_values, solver_parameters,
overwrite_solver_parameters)

@timed_function("Gusto:SolverSetup")
def _setup_solver(self):

equations = self.equations
dt = self.dt
beta_ = dt*self.alpha
# Set relaxation parameters. If an alternative has not been given, set
# to semi-implicit off-centering factor
beta_u_ = dt*self.tau_values.get("u", self.alpha)
beta_t_ = dt*self.tau_values.get("theta", self.alpha)
beta_r_ = dt*self.tau_values.get("rho", self.alpha)

cp = equations.parameters.cp
Vu = equations.domain.spaces("HDiv")
Vu_broken = FunctionSpace(equations.domain.mesh, BrokenElement(Vu.ufl_element()))
Vtheta = equations.domain.spaces("theta")
Vrho = equations.domain.spaces("DG")

# Store time-stepping coefficients as UFL Constants
beta = Constant(beta_)
beta_cp = Constant(beta_ * cp)
beta_u = Constant(beta_u_)
beta_t = Constant(beta_t_)
beta_r = Constant(beta_r_)
beta_u_cp = Constant(beta_u * cp)

h_deg = Vrho.ufl_element().degree()[0]
v_deg = Vrho.ufl_element().degree()[1]
Expand All @@ -206,7 +218,7 @@ def _setup_solver(self):

# Analytical (approximate) elimination of theta
k = equations.domain.k # Upward pointing unit vector
theta = -dot(k, u)*dot(k, grad(thetabar))*beta + theta_in
theta = -dot(k, u)*dot(k, grad(thetabar))*beta_t + theta_in

# Only include theta' (rather than exner') in the vertical
# component of the gradient
Expand Down Expand Up @@ -285,21 +297,21 @@ def L_tr(f):
eqn = (
# momentum equation
u_mass
- beta_cp*div(theta_w*V(w))*exnerbar*dxp
- beta_u_cp*div(theta_w*V(w))*exnerbar*dxp
# following does nothing but is preserved in the comments
# to remind us why (because V(w) is purely vertical).
# + beta_cp*jump(theta_w*V(w), n=n)*exnerbar_avg('+')*dS_vp
+ beta_cp*jump(theta_w*V(w), n=n)*exnerbar_avg('+')*dS_hp
+ beta_cp*dot(theta_w*V(w), n)*exnerbar_avg*ds_tbp
- beta_cp*div(thetabar_w*w)*exner*dxp
+ beta_u_cp*jump(theta_w*V(w), n=n)*exnerbar_avg('+')*dS_hp
+ beta_u_cp*dot(theta_w*V(w), n)*exnerbar_avg*ds_tbp
- beta_u_cp*div(thetabar_w*w)*exner*dxp
# trace terms appearing after integrating momentum equation
+ beta_cp*jump(thetabar_w*w, n=n)*l0('+')*(dS_vp + dS_hp)
+ beta_cp*dot(thetabar_w*w, n)*l0*(ds_tbp + ds_vp)
+ beta_u_cp*jump(thetabar_w*w, n=n)*l0('+')*(dS_vp + dS_hp)
+ beta_u_cp*dot(thetabar_w*w, n)*l0*(ds_tbp + ds_vp)
# mass continuity equation
+ (phi*(rho - rho_in) - beta*inner(grad(phi), u)*rhobar)*dx
+ beta*jump(phi*u, n=n)*rhobar_avg('+')*(dS_v + dS_h)
+ (phi*(rho - rho_in) - beta_r*inner(grad(phi), u)*rhobar)*dx
+ beta_r*jump(phi*u, n=n)*rhobar_avg('+')*(dS_v + dS_h)
# term added because u.n=0 is enforced weakly via the traces
+ beta*phi*dot(u, n)*rhobar_avg*(ds_tb + ds_v)
+ beta_r*phi*dot(u, n)*rhobar_avg*(ds_tb + ds_v)
# constraint equation to enforce continuity of the velocity
# through the interior facets and weakly impose the no-slip
# condition
Expand Down Expand Up @@ -342,7 +354,7 @@ def L_tr(f):

self.theta = Function(Vtheta)
theta_eqn = gamma*(theta - theta_in
+ dot(k, self.u_hdiv)*dot(k, grad(thetabar))*beta)*dx
+ dot(k, self.u_hdiv)*dot(k, grad(thetabar))*beta_t)*dx

theta_problem = LinearVariationalProblem(lhs(theta_eqn), rhs(theta_eqn), self.theta)
self.theta_solver = LinearVariationalSolver(theta_problem,
Expand Down Expand Up @@ -446,13 +458,19 @@ def _setup_solver(self):
equation = self.equations # just cutting down line length a bit

dt = self.dt
beta_ = dt*self.alpha
# Set relaxation parameters. If an alternative has not been given, set
# to semi-implicit off-centering factor
beta_u_ = dt*self.tau_values.get("u", self.alpha)
beta_p_ = dt*self.tau_values.get("p", self.alpha)
beta_b_ = dt*self.tau_values.get("b", self.alpha)
Vu = equation.domain.spaces("HDiv")
Vb = equation.domain.spaces("theta")
Vp = equation.domain.spaces("DG")

# Store time-stepping coefficients as UFL Constants
beta = Constant(beta_)
beta_u = Constant(beta_u_)
beta_p = Constant(beta_p_)
beta_b = Constant(beta_b_)

# Split up the rhs vector (symbolically)
self.xrhs = Function(self.equations.function_space)
Expand All @@ -468,21 +486,21 @@ def _setup_solver(self):

# Analytical (approximate) elimination of theta
k = equation.domain.k # Upward pointing unit vector
b = -dot(k, u)*dot(k, grad(bbar))*beta + b_in
b = -dot(k, u)*dot(k, grad(bbar))*beta_b + b_in

# vertical projection
def V(u):
return k*inner(u, k)

eqn = (
inner(w, (u - u_in))*dx
- beta*div(w)*p*dx
- beta*inner(w, k)*b*dx
- beta_u*div(w)*p*dx
- beta_u*inner(w, k)*b*dx
)

if equation.compressible:
cs = equation.parameters.cs
eqn += phi * (p - p_in) * dx + beta * phi * cs**2 * div(u) * dx
eqn += phi * (p - p_in) * dx + beta_p * phi * cs**2 * div(u) * dx
else:
eqn += phi * div(u) * dx

Expand Down Expand Up @@ -519,7 +537,7 @@ def trace_nullsp(T):
self.b = Function(Vb)

b_eqn = gamma*(b - b_in
+ dot(k, u)*dot(k, grad(bbar))*beta)*dx
+ dot(k, u)*dot(k, grad(bbar))*beta_b)*dx

b_problem = LinearVariationalProblem(lhs(b_eqn),
rhs(b_eqn),
Expand Down Expand Up @@ -589,7 +607,9 @@ class ThermalSWSolver(TimesteppingSolver):
def _setup_solver(self):
equation = self.equations # just cutting down line length a bit
dt = self.dt
beta_ = dt*self.alpha
beta_u_ = dt*self.tau_values.get("u", self.alpha)
beta_d_ = dt*self.tau_values.get("D", self.alpha)
beta_b_ = dt*self.tau_values.get("b", self.alpha)
Vu = equation.domain.spaces("HDiv")
VD = equation.domain.spaces("DG")
Vb = equation.domain.spaces("DG")
Expand All @@ -599,7 +619,9 @@ def _setup_solver(self):
raise NotImplementedError("Field 'b' must exist to use the thermal linear solver in the SIQN scheme")

# Store time-stepping coefficients as UFL Constants
beta = Constant(beta_)
beta_u = Constant(beta_u_)
beta_d = Constant(beta_d_)
beta_b = Constant(beta_b_)

# Split up the rhs vector
self.xrhs = Function(self.equations.function_space)
Expand All @@ -617,20 +639,20 @@ def _setup_solver(self):
bbar = split(equation.X_ref)[2]

# Approximate elimination of b
b = -dot(u, grad(bbar))*beta + b_in
b = -dot(u, grad(bbar))*beta_b + b_in

n = FacetNormal(equation.domain.mesh)

eqn = (
inner(w, (u - u_in)) * dx
- beta * (D - Dbar) * div(w*bbar) * dx
+ beta * jump(w*bbar, n) * avg(D-Dbar) * dS
- beta * 0.5 * Dbar * bbar * div(w) * dx
- beta * 0.5 * Dbar * b * div(w) * dx
- beta * 0.5 * bbar * div(w*(D-Dbar)) * dx
+ beta * 0.5 * jump((D-Dbar)*w, n) * avg(bbar) * dS
- beta_u * (D - Dbar) * div(w*bbar) * dx
+ beta_u * jump(w*bbar, n) * avg(D-Dbar) * dS
- beta_u * 0.5 * Dbar * bbar * div(w) * dx
- beta_u * 0.5 * Dbar * b * div(w) * dx
- beta_u * 0.5 * bbar * div(w*(D-Dbar)) * dx
+ beta_u * 0.5 * jump((D-Dbar)*w, n) * avg(bbar) * dS
+ inner(phi, (D - D_in)) * dx
+ beta * phi * Dbar * div(u) * dx
+ beta_d * phi * Dbar * div(u) * dx
)

aeqn = lhs(eqn)
Expand Down Expand Up @@ -660,7 +682,7 @@ def trace_nullsp(T):
u, D = self.uD.subfunctions
self.b = Function(Vb)

b_eqn = gamma*(b - b_in + inner(u, grad(bbar))*beta) * dx
b_eqn = gamma*(b - b_in + inner(u, grad(bbar))*beta_b) * dx

b_problem = LinearVariationalProblem(lhs(b_eqn),
rhs(b_eqn),
Expand Down Expand Up @@ -814,12 +836,14 @@ class MoistConvectiveSWSolver(TimesteppingSolver):
def _setup_solver(self):
equation = self.equations # just cutting down line length a bit
dt = self.dt
beta_ = dt*self.alpha
beta_u_ = dt*self.tau_values.get("u", self.alpha)
beta_d_ = dt*self.tau_values.get("D", self.alpha)
Vu = equation.domain.spaces("HDiv")
VD = equation.domain.spaces("DG")

# Store time-stepping coefficients as UFL Constants
beta = Constant(beta_)
beta_u = Constant(beta_u_)
beta_d = Constant(beta_d_)

# Split up the rhs vector
self.xrhs = Function(self.equations.function_space)
Expand All @@ -838,9 +862,9 @@ def _setup_solver(self):

eqn = (
inner(w, (u - u_in)) * dx
- beta * (D - Dbar) * div(w*g) * dx
- beta_u * (D - Dbar) * div(w*g) * dx
+ inner(phi, (D - D_in)) * dx
+ beta * phi * Dbar * div(u) * dx
+ beta_d * phi * Dbar * div(u) * dx
)

aeqn = lhs(eqn)
Expand Down
31 changes: 30 additions & 1 deletion gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
diffusion_schemes=None, physics_schemes=None,
slow_physics_schemes=None, fast_physics_schemes=None,
alpha=Constant(0.5), off_centred_u=False,
num_outer=2, num_inner=2):
num_outer=2, num_inner=2, accelerator=False):

"""
Args:
Expand Down Expand Up @@ -84,6 +84,8 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
implicit forcing (pressure gradient and Coriolis) terms, and the
linear solve. Defaults to 2. Note that default used by the Met
Office's ENDGame and GungHo models is 2.
accelerator (bool, optional): Whether to zero non-wind implicit forcings
for transport terms in order to speed up solver convergence
"""

self.num_outer = num_outer
Expand Down Expand Up @@ -120,10 +122,12 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
+ f"physics scheme {parametrisation.label.label}")

self.active_transport = []
self.transported_fields = []
for scheme in transport_schemes:
assert scheme.nlevels == 1, "multilevel schemes not supported as part of this timestepping loop"
assert scheme.field_name in equation_set.field_names
self.active_transport.append((scheme.field_name, scheme))
self.transported_fields.append(scheme.field_name)
# Check that there is a corresponding transport method
method_found = False
for method in spatial_methods:
Expand Down Expand Up @@ -184,6 +188,7 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.linear_solver = linear_solver
self.forcing = Forcing(equation_set, self.alpha)
self.bcs = equation_set.bcs
self.accelerator = accelerator

def _apply_bcs(self):
"""
Expand Down Expand Up @@ -302,6 +307,9 @@ def timestep(self):
with timed_stage("Apply forcing terms"):
logger.info(f'Semi-implicit Quasi Newton: Implicit forcing {(outer, inner)}')
self.forcing.apply(xp, xnp1, xrhs, "implicit")
if (inner > 0 and self.accelerator):
# Zero implicit forcing to accelerate solver convergence
self.forcing.zero_forcing_terms(self.equation, xp, xrhs, self.transported_fields)

xrhs -= xnp1(self.field_name)
xrhs += xrhs_phys
Expand Down Expand Up @@ -477,3 +485,24 @@ def apply(self, x_in, x_nl, x_out, label):

x_out.assign(x_in(self.field_name))
x_out += self.xF

def zero_forcing_terms(self, equation, x_in, x_out, transported_field_names):
"""
Zero forcing term F(x) for non-wind transport.
This takes x_in and x_out, where \n
x_out = x_in + scale*F(x_nl) \n
for some field x_nl and sets x_out = x_in for all non-wind transport terms
Args:
equation (:class:`PrognosticEquationSet`): the prognostic
equation set to be solved
x_in (:class:`FieldCreator`): the field to be incremented.
x_out (:class:`FieldCreator`): the output field to be updated.
transported_field_names (str): list of fields names for transported fields
"""
for field_name in transported_field_names:
if field_name != 'u':
logger.info(f'Semi-Implicit Quasi Newton: Zeroing implicit forcing for {field_name}')
field_index = equation.field_names.index(field_name)
x_out.subfunctions[field_index].assign(x_in(field_name))

0 comments on commit e4cbef6

Please sign in to comment.