Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DIRK boundary conditions #103

Merged
merged 11 commits into from
Nov 20, 2024
153 changes: 59 additions & 94 deletions irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,16 @@
import numpy
from firedrake import DirichletBC, Function
from firedrake import Function
from firedrake import NonlinearVariationalProblem as NLVP
from firedrake import NonlinearVariationalSolver as NLVS
from firedrake import assemble, split, project
from firedrake.__future__ import interpolate
from firedrake import split
from ufl.constantvalue import as_ufl

from .deriv import TimeDerivative
from .tools import replace, MeshConstant
from .bcs import bc2space


class BCThingy:
def __init__(self):
pass

def __call__(self, u):
return u


class BCCompOfNotMixedThingy:
def __init__(self, comp):
self.comp = comp

def __call__(self, u):
return u[self.comp]


class BCMixedBitThingy:
def __init__(self, sub):
self.sub = sub

def __call__(self, u):
return u.sub(self.sub)


class BCCompOfMixedBitThingy:
def __init__(self, sub, comp):
self.sub = sub
self.comp = comp

def __call__(self, u):
return u.sub(self.sub)[self.comp]


def getThingy(V, bc):
num_fields = len(V)
Vbc = bc.function_space()
if num_fields == 1:
comp = Vbc.component
if comp is None:
return BCThingy()
else:
return BCCompOfNotMixedThingy(comp)
else:
sub = bc.function_space_index()
comp = Vbc.component
if comp is None:
return BCMixedBitThingy(sub)
else:
return BCCompOfMixedBitThingy(sub, comp)


def getFormDIRK(F, butch, t, dt, u0, bcs=None):
def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None):
if bcs is None:
bcs = []

Expand All @@ -71,7 +20,7 @@ def getFormDIRK(F, butch, t, dt, u0, bcs=None):
assert V == u0.function_space()

num_fields = len(V)

num_stages = butch.num_stages
k = Function(V)
g = Function(V)

Expand Down Expand Up @@ -101,31 +50,31 @@ def getFormDIRK(F, butch, t, dt, u0, bcs=None):
stage_F = replace(F, repl)

bcnew = []
gblah = []

# For the DIRK case, we need one new BC for each old one (rather
# than one per stage), but we need a `Function` inside of each BC
# and a rule for computing that function at each time for each
# stage.
a_vals = numpy.array([MC.Constant(0) for i in range(num_stages)],
dtype=object)
d_val = MC.Constant(1.0)
for bc in bcs:
Vbc = bc.function_space()
bcarg = as_ufl(bc._original_arg)
bcarg_stage = replace(bcarg, {t: t+c*dt})
try:
gdat = assemble(interpolate(bcarg, Vbc))
gmethod = lambda gd, gc: gd.interpolate(gc)
except: # noqa: E722
gdat = assemble(project(bcarg, Vbc))
gmethod = lambda gd, gc: gd.project(gc)
if bcarg_stage == 0:
# Homogeneous BC, just zero out stage dofs
bcnew.append(bc.reconstruct(g=0))
continue

gdat = bcarg_stage - bc2space(bc, u0)
for i in range(num_stages):
gdat -= dt*a_vals[i]*bc2space(bc, ks[i])

new_bc = DirichletBC(Vbc, gdat, bc.sub_domain)
bcnew.append(new_bc)
gdat /= dt*d_val

dat4bc = getThingy(V, bc)
gdat2 = Function(gdat.function_space())
gblah.append((gdat, gdat2, bcarg_stage, gmethod, dat4bc))
bcnew.append(bc.reconstruct(g=gdat))

return stage_F, (k, g, a, c), bcnew, gblah
return stage_F, (k, g, a, c), bcnew, (a_vals, d_val)


class DIRKTimeStepper:
Expand All @@ -142,12 +91,33 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
self.num_linear_iterations = 0

self.butcher_tableau = butcher_tableau
self.num_stages = num_stages = butcher_tableau.num_stages
self.AAb = numpy.vstack((butcher_tableau.A, butcher_tableau.b))
self.CCone = numpy.append(butcher_tableau.c, 1.0)

# Need to be able to set BCs for either the DIRK or explicit cases.

# For DIRK, we say that the stage i solution should match the
# prescribed boundary values at time c[i], which means we use
# the i^th row of the Butcher tableau in determining the
# boundary condition

# For explicit, we say that the stage i solution should be
# determined to match the prescribed boundary values at time
# c[i+1] (the first stage where it appears), which means we
# use the (i+1)^st row of the Butcher tableau in determining
# the boundary condition, and the full reconstruction for the
# final stage

if butcher_tableau.is_explicit:
self.AAb = self.AAb[1:]
self.CCone = self.CCone[1:]

self.V = V = u0.function_space()
self.u0 = u0
self.t = t
self.dt = dt
self.num_fields = len(u0.function_space())
self.num_stages = num_stages = butcher_tableau.num_stages
self.ks = [Function(V) for _ in range(num_stages)]

# "k" is a generic function for which we will solve the
Expand All @@ -156,11 +126,10 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
# that we update as we go. We need to remember the
# stage values we've computed earlier in the time step...

stage_F, (k, g, a, c), bcnew, gblah = getFormDIRK(
F, butcher_tableau, t, dt, u0, bcs=bcs)
stage_F, (k, g, a, c), bcnew, (a_vals, d_val) = getFormDIRK(
F, self.ks, butcher_tableau, t, dt, u0, bcs=bcs)

self.bcnew = bcnew
self.gblah = gblah

appctx_irksome = {"F": F,
"butcher_tableau": butcher_tableau,
Expand All @@ -181,6 +150,19 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None,
nullspace=nullspace)

self.kgac = k, g, a, c
self.bc_constants = a_vals, d_val

def update_bc_constants(self, i, c):
AAb = self.AAb
CCone = self.CCone
a_vals, d_val = self.bc_constants
ns = AAb.shape[1]
for j in range(i):
a_vals[j].assign(AAb[i, j])
for j in range(i, ns):
a_vals[j].assign(0)
d_val.assign(AAb[i, i])
c.assign(CCone[i])

def advance(self):
k, g, a, c = self.kgac
Expand All @@ -189,14 +171,9 @@ def advance(self):
dtc = float(self.dt)
bt = self.butcher_tableau
AA = bt.A
CC = bt.c
BB = bt.b
gsplit = g.subfunctions
for i in range(self.num_stages):
# update a, c constants tucked into the variational problem
# for the current stage
a.assign(AA[i, i])
c.assign(CC[i])
# compute the already-known part of the state in the
# variational form
g.assign(u0)
Expand All @@ -205,21 +182,9 @@ def advance(self):
for (gbit, kbit) in zip(gsplit, ksplit):
gbit += dtc * float(AA[i, j]) * kbit

# update BC's for the variational problem
for (bc, (gdat, gdat2, gcur, gmethod, dat4bc)) in zip(self.bcnew, self.gblah):
# Evaluate the Dirichlet BC at the current stage time
gmethod(gdat, gcur)

gmethod(gdat2, dat4bc(u0))
gdat -= gdat2

# Subtract previous stage values
for j in range(i):
gmethod(gdat2, dat4bc(ks[j]))
gdat -= dtc * float(AA[i, j]) * gdat2

# Rescale gdat
gdat /= dtc * float(AA[i, i])
# update BC constants for the variational problem
self.update_bc_constants(i, c)
a.assign(AA[i, i])

# solve new variational problem, stash the computed
# stage value.
Expand Down
74 changes: 74 additions & 0 deletions tests/test_explicit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from math import isclose

import pytest
from firedrake import *
from irksome import PEPRK, Dt, MeshConstant, TimeStepper
from ufl.algorithms.ad import expand_derivatives

peprks = [PEPRK(*x) for x in ((4, 2, 5), (5, 2, 6))]


# Note that this test is constructed with dt small enough relative to
# dx that these explicit methods stay stable -- while Irksome provides
# support for explicit schemes, we also caution users that there are
# no checks in the code that the method you are trying to run is
# actually sensible!
@pytest.mark.parametrize("butcher_tableau", peprks)
def test_1d_heat_dirichletbc(butcher_tableau):
# Boundary values
u_0 = Constant(2.0)
u_1 = Constant(3.0)

N = 10
x0 = 0.0
x1 = 10.0
msh = IntervalMesh(N, x1)
V = FunctionSpace(msh, "CG", 1)
MC = MeshConstant(msh)
dt = MC.Constant(1.0 / N)
t = MC.Constant(0.0)
(x,) = SpatialCoordinate(msh)

# Method of manufactured solutions copied from Heat equation demo.
S = Constant(2.0)
C = Constant(1000.0)
B = (x - Constant(x0)) * (x - Constant(x1)) / C
R = (x * x) ** 0.5
# Note end linear contribution
uexact = (
B * atan(t) * (pi / 2.0 - atan(S * (R - t)))
+ u_0
+ ((x - x0) / x1) * (u_1 - u_0)
)
rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact))
u = Function(V)
u.interpolate(uexact)
v = TestFunction(V)
F = (
inner(Dt(u), v) * dx
+ inner(grad(u), grad(v)) * dx
- inner(rhs, v) * dx
)
bc = [
DirichletBC(V, u_1, 2),
DirichletBC(V, u_0, 1),
]

luparams = {"mat_type": "aij", "ksp_type": "preonly", "pc_type": "lu"}

stepper = TimeStepper(
F, butcher_tableau, t, dt, u, bcs=bc,
solver_parameters=luparams,
stage_type="explicit"
)

t_end = 2.0
while float(t) < t_end:
if float(t) + float(dt) > t_end:
dt.assign(t_end - float(t))
stepper.advance()
t.assign(float(t) + float(dt))
# Check solution and boundary values
assert errornorm(uexact, u) / norm(uexact) < 10.0 ** -3
assert isclose(u.at(x0), u_0)
assert isclose(u.at(x1), u_1)
Loading