Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increment form for implicit RK added and tested #566

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
self.butcher_matrix = butcher_matrix
self.nbutcher = int(np.shape(self.butcher_matrix)[0])
self.nStages = int(np.shape(self.butcher_matrix)[0])
self.rk_formulation = rk_formulation

@property
def nStages(self):
return self.nbutcher

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Set up the time discretisation based on the equation.
Expand Down
190 changes: 149 additions & 41 deletions gusto/time_discretisation/implicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np

from firedrake import (Function, split, NonlinearVariationalProblem,
NonlinearVariationalSolver)
NonlinearVariationalSolver, Constant)
from firedrake.fml import replace_subject, all_terms, drop
from firedrake.utils import cached_property

from gusto.core.labels import time_derivative
from gusto.time_discretisation.time_discretisation import (
TimeDiscretisation, wrapper_apply
)
from gusto.time_discretisation.explicit_runge_kutta import RungeKuttaFormulation


__all__ = ["ImplicitRungeKutta", "ImplicitMidpoint", "QinZhang"]
Expand All @@ -30,7 +31,9 @@ class ImplicitRungeKutta(TimeDiscretisation):
For each i = 1, s in an s stage method
we have the intermediate solutions: \n
y_i = y^n + dt*(a_i1*k_1 + a_i2*k_2 + ... + a_ii*k_i) \n
We compute the gradient at the intermediate location, k_i = F(y_i) \n
For the increment form we compute the gradient at the \n
intermediate location, k_i = F(y_i), whilst for the \n
predictor form we solve for each intermediate solution y_i. \n

At the last stage, compute the new solution by: \n
y^{n+1} = y^n + dt*(b_1*k_1 + b_2*k_2 + .... + b_s*k_s)
Expand All @@ -56,6 +59,7 @@ class ImplicitRungeKutta(TimeDiscretisation):
# ---------------------------------------------------------------------------
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth us making the predictor and increment forms clear in the docstrings?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made it a bit more clear, describing what we are solving for


def __init__(self, domain, butcher_matrix, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None,):
"""
Args:
Expand All @@ -66,6 +70,9 @@ def __init__(self, domain, butcher_matrix, field_name=None,
discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -78,6 +85,7 @@ def __init__(self, domain, butcher_matrix, field_name=None,
options=options)
self.butcher_matrix = butcher_matrix
self.nStages = int(np.shape(self.butcher_matrix)[1])
self.rk_formulation = rk_formulation

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Expand All @@ -91,31 +99,105 @@ def setup(self, equation, apply_bcs=True, *active_labels):

super().setup(equation, apply_bcs, *active_labels)

self.k = [Function(self.fs) for i in range(self.nStages)]

def lhs(self):
return super().lhs

def rhs(self):
return super().rhs
if self.rk_formulation == RungeKuttaFormulation.predictor:
self.xs = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.increment:
self.k = [Function(self.fs) for _ in range(self.nStages)]
elif self.rk_formulation == RungeKuttaFormulation.linear:
raise NotImplementedError(
'Linear Implicit Runge-Kutta formulation is not implemented'
)
else:
raise NotImplementedError(
'Runge-Kutta formulation is not implemented'
)

def solver(self, stage):
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
def res(self, stage):
"""Set up the residual for the predictor formulation for a given stage."""
# Add time derivative terms y_s - y^n for stage s
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calculate sum
# dt*(a_s1*F(y_1) + a_s2*F(y_2)+ ... + a_{s,s-1}*F(y_{s-1}))
for i in range(stage):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, i])*self.dt*t)
residual += r_imp
# Calculate and add on dt*a_ss*F(y_s)
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.x_out, old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[stage, stage])*self.dt*t)
residual += r_imp
return residual.form

@property
def final_res(self):
"""Set up the final residual for the predictor formulation."""
# Add time derivative terms y^{n+1} - y^n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make clear that this is only for the predictor formulation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
# dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s))
for i in range(self.nStages):
r_imp = self.residual.label_map(
lambda t: not t.has_label(time_derivative),
map_if_true=replace_subject(self.xs[i], old_idx=self.idx),
map_if_false=drop)
r_imp = r_imp.label_map(
all_terms,
map_if_true=lambda t: Constant(self.butcher_matrix[self.nStages, i])*self.dt*t)
residual += r_imp
return residual.form

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
def solver(self, stage):
if self.rk_formulation == RungeKuttaFormulation.increment:
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.xnph, self.idx),
)
mass_form = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual += mass_form.label_map(all_terms,
replace_subject(self.x_out, self.idx))

problem = NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)

elif self.rk_formulation == RungeKuttaFormulation.predictor:
problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs)

solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name)
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def final_solver(self):
"""
Set up a solver for the final solve for the predictor
formulation to evaluate time level n+1.
"""
# setup solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@cached_property
def solvers(self):
Expand All @@ -126,32 +208,48 @@ def solvers(self):

def solve_stage(self, x0, stage):
self.x1.assign(x0)
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])
if self.rk_formulation == RungeKuttaFormulation.increment:
for i in range(stage):
self.x1.assign(self.x1 + self.butcher_matrix[stage, i]*self.dt*self.k[i])

if self.idx is None and len(self.fs) > 1:
self.xnph = tuple([self.dt*self.butcher_matrix[stage, stage]*a + b
for a, b in zip(split(self.x_out), split(self.x1))])
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out
solver = self.solvers[stage]
# Set initial guess for solver
if (stage > 0):
self.x_out.assign(self.k[stage-1])
if self.idx is None and len(self.fs) > 1:
self.xnph = tuple(
self.dt * self.butcher_matrix[stage, stage] * a + b
for a, b in zip(split(self.x_out), split(self.x1))
)
else:
self.xnph = self.x1 + self.butcher_matrix[stage, stage]*self.dt*self.x_out

solver.solve()
solver = self.solvers[stage]

self.k[stage].assign(self.x_out)
# Set initial guess for solver
if (stage > 0):
self.x_out.assign(self.k[stage-1])

solver.solve()
self.k[stage].assign(self.x_out)

elif self.rk_formulation == RungeKuttaFormulation.predictor:
if (stage > 0):
self.x_out.assign(self.xs[stage-1])
solver = self.solvers[stage]
solver.solve()

self.xs[stage].assign(self.x_out)

@wrapper_apply
def apply(self, x_out, x_in):

self.x_out.assign(x_in)
for i in range(self.nStages):
self.solve_stage(x_in, i)

x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
if self.rk_formulation == RungeKuttaFormulation.increment:
x_out.assign(x_in)
for i in range(self.nStages):
x_out.assign(x_out + self.butcher_matrix[self.nStages, i]*self.dt*self.k[i])
elif self.rk_formulation == RungeKuttaFormulation.predictor:
self.final_solver.solve()
x_out.assign(self.x_out)


class ImplicitMidpoint(ImplicitRungeKutta):
Expand All @@ -164,14 +262,18 @@ class ImplicitMidpoint(ImplicitRungeKutta):
k0 = F[y^n + 0.5*dt*k0] \n
y^(n+1) = y^n + dt*k0 \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
options=None):
def __init__(self, domain, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -181,6 +283,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.5], [1.]])
super().__init__(domain, butcher_matrix, field_name,
rk_formulation=rk_formulation,
solver_parameters=solver_parameters,
options=options)

Expand All @@ -196,14 +299,18 @@ class QinZhang(ImplicitRungeKutta):
k1 = F[y^n + 0.5*dt*k0 + 0.25*dt*k1] \n
y^(n+1) = y^n + 0.5*dt*(k0 + k1) \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
options=None):
def __init__(self, domain, field_name=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
rk_formulation (:class:`RungeKuttaFormulation`, optional):
an enumerator object, describing the formulation of the Runge-
Kutta scheme. Defaults to the increment form.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -213,5 +320,6 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
"""
butcher_matrix = np.array([[0.25, 0], [0.5, 0.25], [0.5, 0.5]])
super().__init__(domain, butcher_matrix, field_name,
rk_formulation=rk_formulation,
solver_parameters=solver_parameters,
options=options)
6 changes: 3 additions & 3 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
operator F.
"""

from abc import ABCMeta, abstractmethod, abstractproperty
from abc import ABCMeta, abstractmethod
import math

from firedrake import (Function, TestFunction, TestFunctions, DirichletBC,
Expand Down Expand Up @@ -261,7 +261,7 @@ def setup(self, equation, apply_bcs=True, *active_labels):
def nlevels(self):
return 1

@abstractproperty
@property
def lhs(self):
"""Set up the discretisation's left hand side (the time derivative)."""
l = self.residual.label_map(
Expand All @@ -271,7 +271,7 @@ def lhs(self):

return l.form

@abstractproperty
@property
def rhs(self):
"""Set up the time discretisation's right hand side."""
r = self.residual.label_map(
Expand Down
11 changes: 7 additions & 4 deletions integration-tests/model/test_time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ def run(timestepper, tmax, f_end):

@pytest.mark.parametrize(
"scheme", [
"ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint", "QinZhang",
"ssprk3_increment", "TrapeziumRule", "ImplicitMidpoint",
"QinZhang_increment", "QinZhang_predictor",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth", "Leapfrog",
"AdamsMoulton", "AdamsMoulton", "ssprk3_predictor", "ssprk3_linear"
"AdamsMoulton", "ssprk3_predictor", "ssprk3_linear"
]
)
def test_time_discretisation(tmpdir, scheme, tracer_setup):
Expand Down Expand Up @@ -40,8 +41,10 @@ def test_time_discretisation(tmpdir, scheme, tracer_setup):
transport_scheme = TrapeziumRule(domain)
elif scheme == "ImplicitMidpoint":
transport_scheme = ImplicitMidpoint(domain)
elif scheme == "QinZhang":
transport_scheme = QinZhang(domain)
elif scheme == "QinZhang_increment":
transport_scheme = QinZhang(domain, rk_formulation=RungeKuttaFormulation.increment)
elif scheme == "QinZhang_predictor":
transport_scheme = QinZhang(domain, rk_formulation=RungeKuttaFormulation.predictor)
elif scheme == "RK4":
transport_scheme = RK4(domain)
elif scheme == "Heun":
Expand Down
Loading