Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New base class DIRK_IMEX #112

Merged
merged 4 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ We start with standard Firedrake/Irksome imports::
from firedrake import (And, Constant, File, Function, FunctionSpace,
RectangleMesh, SpatialCoordinate, TestFunctions,
as_matrix, conditional, dx, grad, inner, split)
from irksome import Dt, MeshConstant, DIRK_IMEX, TimeStepper
from irksome import Dt, MeshConstant, TimeStepper, ARS_DIRK_IMEX

And we set up the mesh and function space.::

Expand Down Expand Up @@ -75,7 +75,7 @@ This sets up the Butcher tableau. Here, we use the DIRK-IMEX methods proposed
by Ascher, Ruuth, and Spiteri in their 1997 Applied Numerical Mathematics paper.
For this case, We use a four-stage method.::

butcher_tableau = DIRK_IMEX(4, 4, 3)
butcher_tableau = ARS_DIRK_IMEX(4, 4, 3)
ns = butcher_tableau.num_stages

To access an IMEX method, we need to separately specify the implicit and explicit parts of the operator.
Expand Down
3 changes: 3 additions & 0 deletions irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from .pep_explicit_rk import PEPRK # noqa: F401
from .deriv import Dt # noqa: F401
from .dirk_imex_tableaux import DIRK_IMEX # noqa: F401
from .ars_dirk_imex_tableaux import ARS_DIRK_IMEX # noqa: F401
from .sspk_tableau import SSPK_DIRK_IMEX # noqa: F401
from .sspk_tableau import SSPButcherTableau # noqa: F401
from .dirk_stepper import DIRKTimeStepper # noqa: F401
from .getForm import getForm # noqa: F401
from .imex import RadauIIAIMEXMethod # noqa: F401
Expand Down
111 changes: 111 additions & 0 deletions irksome/ars_dirk_imex_tableaux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np

from .dirk_imex_tableaux import DIRK_IMEX

# Butcher tableau based on Ascher, Ruuth, and Spiteri Applied Numerical Mathematics 1997 (ARS)

# ARS tableau assume a zero first column of the implicit A matrix, so only the lower right s x s
# block is given for the implicit scheme. b and c are also of length s.

# Butcher tableau for s = 1
ars111A = np.array([[1.0]])
ars111A_hat = np.array([[1.0]])
ars111b = np.array([1.0])
ars111b_hat = np.array([1.0, 0.0])
ars111c = np.array([1.0])
ars111c_hat = np.array([0.0, 1.0])

ars121A = np.array([[1.0]])
ars121A_hat = np.array([[1.0]])
ars121b = np.array([1.0])
ars121b_hat = np.array([0.0, 1.0])
ars121c = np.array([1.0])
ars121c_hat = np.array([0.0, 1.0])

ars122A = np.array([[0.5]])
ars122A_hat = np.array([[0.5]])
ars122b = np.array([1.0])
ars122b_hat = np.array([0.0, 1.0])
ars122c = np.array([0.5])
ars122c_hat = np.array([0.0, 0.5])

# Butcher tableau for s = 2
gamma233 = (3 + np.sqrt(3))/6
ars233A = np.array([[gamma233, 0], [1 - 2*gamma233, gamma233]])
ars233A_hat = np.array([[gamma233, 0], [gamma233 - 1, 2*(1 - gamma233)]])
ars233b = np.array([0.5, 0.5])
ars233b_hat = np.array([0, 0.5, 0.5])
ars233c = np.array([gamma233, 1 - gamma233])
ars233c_hat = np.array([0, gamma233, 1 - gamma233])

gamma232 = (2 - np.sqrt(2)) / 2
delta232 = -2 * np.sqrt(2) / 3
ars232A = np.array([[gamma232, 0], [1 - gamma232, gamma232]])
ars232A_hat = np.array([[gamma232, 0], [delta232, 1 - delta232]])
ars232b = np.array([1 - gamma232, gamma232])
ars232b_hat = np.array([0, 1 - gamma232, gamma232])
ars232c = np.array([gamma232, 1.0])
ars232c_hat = np.array([0, gamma232, 1.0])

gamma222 = gamma232
delta222 = 1 - 1/(2*gamma222)
ars222A = np.array([[gamma222, 0], [1 - gamma222, gamma222]])
ars222A_hat = np.array([[gamma222, 0], [delta222, 1 - delta222]])
ars222b = np.array([1 - gamma222, gamma222])
ars222b_hat = np.array([delta222, 1 - delta222, 0])
ars222c = np.array([gamma222, 1.0])
ars222c_hat = np.array([0, gamma222, 1.0])

# Butcher tableau for s = 3
ars343A = np.array([[0.4358665215, 0, 0], [0.2820667392, 0.4358665215, 0], [1.208496649, -0.644363171, 0.4358665215]])
ars343A_hat = np.array([[0.4358665215, 0, 0], [0.3212788860, 0.3966543747, 0], [-0.105858296, 0.5529291479, 0.5529291479]])
ars343b = np.array([1.208496649, -0.644363171, 0.4358665215])
ars343b_hat = np.array([0, 1.208496649, -0.644363171, 0.4358665215])
ars343c = np.array([0.4358665215, 0.7179332608, 1])
ars343c_hat = np.array([0, 0.4358665215, 0.7179332608, 1.0])

# Butcher tableau for s = 4
ars443A = np.array([[1/2, 0, 0, 0],
[1/6, 1/2, 0, 0],
[-1/2, 1/2, 1/2, 0],
[3/2, -3/2, 1/2, 1/2]])
ars443A_hat = np.array([[1/2, 0, 0, 0],
[11/18, 1/18, 0, 0],
[5/6, -5/6, 1/2, 0],
[1/4, 7/4, 3/4, -7/4]])
ars443b = np.array([3/2, -3/2, 1/2, 1/2])
ars443b_hat = np.array([1/4, 7/4, 3/4, -7/4, 0])
ars443c = np.array([1/2, 2/3, 1/2, 1])
ars443c_hat = np.array([0, 1/2, 2/3, 1/2, 1])

ars_dict = {
(1, 1, 1): (ars111A, ars111b, ars111c, ars111A_hat, ars111b_hat, ars111c_hat),
(1, 2, 1): (ars121A, ars121b, ars121c, ars121A_hat, ars121b_hat, ars121c_hat),
(1, 2, 2): (ars122A, ars122b, ars122c, ars122A_hat, ars122b_hat, ars122c_hat),
(2, 2, 2): (ars222A, ars222b, ars222c, ars222A_hat, ars222b_hat, ars222c_hat),
(2, 3, 2): (ars232A, ars232b, ars232c, ars232A_hat, ars232b_hat, ars232c_hat),
(2, 3, 3): (ars233A, ars233b, ars233c, ars233A_hat, ars233b_hat, ars233c_hat),
(3, 4, 3): (ars343A, ars343b, ars343c, ars343A_hat, ars343b_hat, ars343c_hat),
(4, 4, 3): (ars443A, ars443b, ars443c, ars443A_hat, ars443b_hat, ars443c_hat)
}


class ARS_DIRK_IMEX(DIRK_IMEX):
"""Class to generate IMEX tableaux based on Ascher, Ruuth, and Spiteri (ARS). It has members

:arg ns_imp: number of implicit stages
:arg ns_exp: number of explicit stages
:arg order: the (integer) former order of accuracy of the method
"""
def __init__(self, ns_imp, ns_exp, order):
try:
A, b, c, A_hat, b_hat, c_hat = ars_dict[ns_imp, ns_exp, order]
except KeyError:
raise NotImplementedError("No ARS DIRK-IMEX method for that combination of implicit and explicit stages and order")

# Expand A, b, c with assumed zeros in ARS tableaux
A = self._pad_matrix(A, "lr")
b = np.append(np.zeros(1), b)
c = np.append(np.zeros(1), c)

super(ARS_DIRK_IMEX, self).__init__(A, b, c, A_hat, b_hat, c_hat, order)
113 changes: 50 additions & 63 deletions irksome/dirk_imex_tableaux.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,58 @@
from .ButcherTableaux import ButcherTableau
import numpy as np

# For the implicit scheme, the full Butcher Table is given as A, b, c.

# For the explicit scheme, the full b_hat and c_hat are given, but (to
# avoid a lot of offset-by-ones in the code we store only the
# lower-left ns x ns block of A_hat

# IMEX Butcher tableau for 1 stage
imex111A = np.array([[1.0]])
imex111A_hat = np.array([[1.0]])
imex111b = np.array([1.0])
imex111b_hat = np.array([1.0, 0.0])
imex111c = np.array([1.0])
imex111c_hat = np.array([0.0, 1.0])


# IMEX Butcher tableau for s = 2
gamma = (2 - np.sqrt(2)) / 2
delta = -2 * np.sqrt(2) / 3
imex232A = np.array([[gamma, 0], [1 - gamma, gamma]])
imex232A_hat = np.array([[gamma, 0], [delta, 1 - delta]])
imex232b = np.array([1 - gamma, gamma])
imex232b_hat = np.array([0, 1 - gamma, gamma])
imex232c = np.array([gamma, 1.0])
imex232c_hat = np.array([0, gamma, 1.0])

# IMEX Butcher tableau for 3 stages
imex343A = np.array([[0.4358665215, 0, 0], [0.2820667392, 0.4358665215, 0], [1.208496649, -0.644363171, 0.4358665215]])
imex343A_hat = np.array([[0.4358665215, 0, 0], [0.3212788860, 0.3966543747, 0], [-0.105858296, 0.5529291479, 0.5529291479]])
imex343b = np.array([1.208496649, -0.644363171, 0.4358665215])
imex343b_hat = np.array([0, 1.208496649, -0.644363171, 0.4358665215])
imex343c = np.array([0.4358665215, 0.7179332608, 1])
imex343c_hat = np.array([0, 0.4358665215, 0.7179332608, 1.0])


# IMEX Butcher tableau for 4 stages
imex443A = np.array([[1/2, 0, 0, 0],
[1/6, 1/2, 0, 0],
[-1/2, 1/2, 1/2, 0],
[3/2, -3/2, 1/2, 1/2]])
imex443A_hat = np.array([[1/2, 0, 0, 0],
[11/18, 1/18, 0, 0],
[5/6, -5/6, 1/2, 0],
[1/4, 7/4, 3/4, -7/4]])
imex443b = np.array([3/2, -3/2, 1/2, 1/2])
imex443b_hat = np.array([1/4, 7/4, 3/4, -7/4, 0])
imex443c = np.array([1/2, 2/3, 1/2, 1])
imex443c_hat = np.array([0, 1/2, 2/3, 1/2, 1])

dirk_imex_dict = {
(1, 1, 1): (imex111A, imex111b, imex111c, imex111A_hat, imex111b_hat, imex111c_hat),
(2, 3, 2): (imex232A, imex232b, imex232c, imex232A_hat, imex232b_hat, imex232c_hat),
(3, 4, 3): (imex343A, imex343b, imex343c, imex343A_hat, imex343b_hat, imex343c_hat),
(4, 4, 3): (imex443A, imex443b, imex443c, imex443A_hat, imex443b_hat, imex443c_hat)
}


class DIRK_IMEX(ButcherTableau):
def __init__(self, ns_imp, ns_exp, order):
try:
A, b, c, A_hat, b_hat, c_hat = dirk_imex_dict[ns_imp, ns_exp, order]
except KeyError:
raise NotImplementedError("No DIRK-IMEX method for that combination of implicit and explicit stages and order")
self.order = order
super(DIRK_IMEX, self).__init__(A, b, None, c, order, None, None)
self.A_hat = A_hat
"""Top-level class representing a pair of Butcher tableau encoding an implicit-explicit
additive Runge-Kutta method. Since the explicit Butcher matrix is strictly lower triangular,
only the lower-left (ns - 1)x(ns - 1) block is given. However, the full b_hat and c_hat are
given. It has members

:arg A: a 2d array containing the implicit Butcher matrix
:arg b: a 1d array giving weights assigned to each implicit stage when
computing the solution at time n+1.
:arg c: a 1d array containing weights at which time-dependent
implicit terms are evaluated.
:arg A_hat: a 2d array containing the explicit Butcher matrix (lower-left block only)
:arg b_hat: a 1d array giving weights assigned to each explicit stage when
computing the solution at time n+1.
:arg c_hat: a 1d array containing weights at which time-dependent
explicit terms are evaluated.
:arg order: the (integer) formal order of accuracy of the method
"""

def __init__(self, A: np.ndarray, b: np.ndarray, c: np.ndarray, A_hat: np.ndarray,
b_hat: np.ndarray, c_hat: np.ndarray, order: int = None):

# Number of stages
ns = A.shape[0]
assert ns == A.shape[1], "A must be square"
assert A_hat.shape == (ns - 1, ns - 1), "A_hat must have one fewer row and column than A"
assert ns == len(b) == len(b_hat), \
"b and b_hat must have the same length as the number of stages"
assert ns == len(c) == len(c_hat), \
"c and c_hat must have the same length as the number of stages"

super().__init__(A, b, None, c, order, None, None)
self.A_hat = self._pad_matrix(A_hat, "ll")
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme

@staticmethod
def _pad_matrix(mat: np.ndarray, loc: str):
"""Zero pads a matrix"""
n = mat.shape[0]
assert n == mat.shape[1], "Matrix must be square"
padded = np.zeros((n+1, n+1), dtype=mat.dtype)

if loc == "ll":
# Lower left corner
padded[1:, :-1] = mat
elif loc == "lr":
# Lower right corner
padded[1:, 1:] = mat
else:
raise ValueError("Location must be ll (lower left) or lr (lower right)")

return padded
Loading