Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rexi updated #423

Open
wants to merge 57 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
4a21a43
rexi implementation
jshipton Aug 15, 2023
da360ed
test for rexi
jshipton Aug 15, 2023
4a66c1a
update logging and time stepper
nhartney Aug 15, 2023
f5598b5
put all REXI parameters in one class called RexiParameters
nhartney Aug 15, 2023
b491443
test for steps of the REXI approximation
nhartney Aug 15, 2023
c8d9d39
fix rexi to work with recent replacement changes and replace .split()…
jshipton Aug 15, 2023
7fd711c
Merge branch 'rexi_updated' of https://github.com/firedrakeproject/gu…
jshipton Aug 15, 2023
78bbef6
fix lint
jshipton Aug 16, 2023
d9448ce
fix lint for one of the test files
jshipton Aug 16, 2023
b38192f
test for REXI using linear shallow water waves on the plane scenario
nhartney Aug 16, 2023
916613d
test_linear_sw.py does the linear sw wave test with REXI and compares…
nhartney Aug 17, 2023
eae9430
fix firedrake split warning
jshipton Aug 17, 2023
9c2aee2
Merge branch 'rexi_updated' of https://github.com/firedrakeproject/gu…
jshipton Aug 17, 2023
9f24f14
remove fml
jshipton Mar 20, 2024
29cf593
fix definition of NullTerm
jshipton Mar 20, 2024
8a935c7
argh lint!
jshipton Mar 20, 2024
0be451e
still failing
jshipton Mar 20, 2024
759c673
allow transporting velocity to be indexed
jshipton Mar 21, 2024
15b53d4
Merge branch 'boussinesq_compressible_switch' of https://github.com/f…
jshipton Mar 21, 2024
f51aefa
some hacking for linear boussinesq equations
jshipton Mar 21, 2024
98769f4
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 2, 2024
abe6229
fix to default linearisation for Boussinesq
jshipton Apr 2, 2024
9fd4f33
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 2, 2024
91fc87f
start of compressible boussinesq test for rexi
jshipton Apr 2, 2024
d7de321
fix lint and test description
jshipton Apr 2, 2024
0f1d1f1
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 3, 2024
5772d98
File -> VTKFile
jshipton Apr 3, 2024
e558533
improve docstring for rexi
jshipton Apr 4, 2024
03d6746
update kgo file for rexi shallow water test
jshipton Apr 17, 2024
1a3a03c
tighten tolerence in rexi shallow water test and some work on rexi bo…
jshipton Apr 18, 2024
3462752
add perturbation field to test
jshipton Apr 18, 2024
39a6db6
fix rexi compressible boussinesq test
jshipton Apr 18, 2024
fa34c94
argh lint!
jshipton Apr 18, 2024
c4df6e3
test tolerence was a bit harsh
jshipton Apr 18, 2024
e9475f1
parallel rexi test
jshipton Apr 18, 2024
5512707
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 18, 2024
8848b5b
Move artifact upload to seperate step
JDBetteridge Aug 7, 2023
31ba546
Don't test for rank on write output
JDBetteridge Apr 19, 2024
7fdda01
Merge branch 'main' into rexi_updated
jshipton May 23, 2024
0416941
rexi - cpx.FunctionSpace
JHopeCollins Jul 22, 2024
e17d7fa
rexi - cpx.get/set_real/imag
JHopeCollins Jul 22, 2024
e5a9ef7
rexi - only allreduce the real components
JHopeCollins Jul 22, 2024
e3d4e78
rexi - use cpx for accumulating the sum
JHopeCollins Jul 22, 2024
c1a7484
rexi - use cpx.DirichletBC
JHopeCollins Jul 22, 2024
c5c827f
rexi - use cpx.split
JHopeCollins Jul 22, 2024
d574a37
rename rexi rhs for clarity
JHopeCollins Jul 23, 2024
c93dcd5
rexi - simplify building the complex system
JHopeCollins Jul 23, 2024
df86638
rexi - calculate alpha*M and rhs with cpx. Still have to refactor tau*L
JHopeCollins Jul 23, 2024
f8b1fd6
rexi - calculate tau*L with cpx.
JHopeCollins Jul 23, 2024
33e17fd
rexi - allow choosing cpx implementation via init arg.
JHopeCollins Jul 23, 2024
bdb9272
rexi - add copy of complex_proxy to gusto.
JHopeCollins Jul 23, 2024
355dbe6
add complex_proxy tests
JHopeCollins Jul 23, 2024
ac2f843
Merge pull request #520 from firedrakeproject/JHopeCollins/rexi_compl…
jshipton Jul 23, 2024
7b4540c
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Jul 23, 2024
1865563
fix import
jshipton Jul 23, 2024
9937599
rexi - implement form_rhs using form_mass.
JHopeCollins Jul 24, 2024
31559b1
Merge branch 'main' into rexi_updated
jshipton Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gusto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def perp(self, o, a):
from gusto.physics import * # noqa
from gusto.preconditioners import * # noqa
from gusto.recovery import * # noqa
from gusto.rexi import * # noqa
from gusto.spatial_methods import * # noqa
from gusto.time_discretisation import * # noqa
from gusto.timeloop import * # noqa
Expand Down
1 change: 1 addition & 0 deletions gusto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


__all__ = [
"Configuration",
"IntegrateByParts", "TransportEquationType", "OutputParameters",
"CompressibleParameters", "ShallowWaterParameters",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions",
Expand Down
3 changes: 2 additions & 1 deletion gusto/fml/form_manipulation_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from firedrake import Constant, Function


__all__ = ["Label", "Term", "LabelledForm", "identity", "drop", "all_terms",
__all__ = ["Label", "Term", "NullTerm", "LabelledForm",
"identity", "drop", "all_terms",
"keep", "subject", "name"]

# ---------------------------------------------------------------------------- #
Expand Down
2 changes: 2 additions & 0 deletions gusto/rexi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from gusto.rexi.rexi import * # noqa
from gusto.rexi.rexi_coefficients import * # noqa
231 changes: 231 additions & 0 deletions gusto/rexi/rexi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
from gusto.rexi.rexi_coefficients import *
from firedrake import Function, TrialFunctions, TestFunctions, \
Constant, DirichletBC, \
LinearVariationalProblem, LinearVariationalSolver, MixedFunctionSpace
from gusto import (replace_subject, drop, time_derivative,
all_terms, replace_test_function, prognostic,
Term, NullTerm, linearisation, subject,
replace_trial_function)
from firedrake.formmanipulation import split_form


class Rexi(object):
"""
Class defining the solver for the system

(A_n + tau L)V_n = U

required for computing the matrix exponential as described in notes.pdf

:arg equation: :class:`.Equation` object defining the equation set to
be solved
:arg rexi_parameters: :class:`.Equation` object
:arg solver_parameters: dictionary of solver parameters. Default None,
which results in the default solver parameters defined in the equation
class being used.
:arg manager: :class:`.Ensemble` object containing the space and ensemble
subcommunicators

"""
def __init__(self, equation, rexi_parameters, *, solver_parameters=None,
manager=None):

residual = equation.residual.label_map(
lambda t: t.has_label(linearisation),
map_if_true=lambda t: Term(t.get(linearisation).form, t.labels),
map_if_false=drop)
residual = residual.label_map(
all_terms,
lambda t: replace_trial_function(t.get(subject))(t))

# Get the Rexi Coefficients, given the values of h and M in
# rexi_parameters
self.alpha, self.beta, self.beta2 = RexiCoefficients(rexi_parameters)

self.manager = manager

# define the start point of the solver loop (idx) and the
# number of solvers (N) for this process depending on the
# total number of solvers (nsolvers) and how many ensemble
# processes (neprocs) there are.
nsolvers = len(self.alpha)
if manager is None:
# if running in serial we loop over all the solvers, from
# 0: nsolvers
self.N = nsolvers
self.idx = 0
else:
rank = manager.ensemble_comm.rank
neprocs = manager.ensemble_comm.size
m = int(nsolvers/neprocs)
p = nsolvers - m*neprocs
if rank < p:
self.N = m+1
self.idx = rank*(m+1)
else:
self.N = m
self.idx = rank*m + p

# set dummy constants for tau and A_i
self.ar = Constant(1.)
self.ai = Constant(1.)
self.tau = Constant(1.)

# set up functions, problem and solver
W_ = equation.function_space
self.w_out = Function(W_)
spaces = []
for i in range(len(W_)):
spaces.append(W_[i])
spaces.append(W_[i])
W = MixedFunctionSpace(spaces)
self.U0 = Function(W)
self.w_sum = Function(W)
self.w = Function(W)
self.w_ = Function(W)
tests = TestFunctions(W)
trials = TrialFunctions(W)
tests_r = tests[::2]
tests_i = tests[1::2]
trials_r = trials[::2]
trials_i = trials[1::2]

ar, ai = self.ar, self.ai
a = NullTerm
L = NullTerm
for i in range(len(W_)):
ith_res = residual.label_map(
lambda t: t.get(prognostic) == equation.field_names[i],
lambda t: Term(
split_form(t.form)[i].form,
t.labels),
map_if_false=drop)

mass_form = ith_res.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)

m = mass_form.label_map(
all_terms,
replace_test_function(tests_r[i]))
a += (
(ar + ai) * m.label_map(all_terms,
replace_subject(trials_r[i], old_idx=i))
+ (ar - ai) * m.label_map(all_terms,
replace_subject(trials_i[i], old_idx=i))
)

L += (
m.label_map(all_terms, replace_subject(self.U0.subfunctions[2*i], i))
+ m.label_map(all_terms, replace_subject(self.U0.subfunctions[2*i+1], old_idx=i))
)

m = mass_form.label_map(
all_terms,
replace_test_function(tests_i[i]))
a += (
(ar - ai) * m.label_map(all_terms,
replace_subject(trials_r[i], old_idx=i))
+ (-ar - ai) * m.label_map(all_terms,
replace_subject(trials_i[i], old_idx=i))
)

L += (
m.label_map(all_terms,
replace_subject(self.U0.subfunctions[2*i], i))
- m.label_map(all_terms,
replace_subject(self.U0.subfunctions[2*i+1], i))
)

L_form = ith_res.label_map(
lambda t: t.has_label(time_derivative),
drop)

Lr = L_form.label_map(
all_terms,
replace_test_function(tests_r[i]))
a -= self.tau * Lr.label_map(all_terms,
replace_subject(trials_r))
a -= self.tau * Lr.label_map(all_terms,
replace_subject(trials_i))

Li = L_form.label_map(
all_terms,
replace_test_function(tests_i[i]))
a -= self.tau * Li.label_map(all_terms,
replace_subject(trials_r))
a += self.tau * Li.label_map(all_terms,
replace_subject(trials_i))

a = a.label_map(lambda t: t is NullTerm, drop)
L = L.label_map(lambda t: t is NullTerm, drop)

if hasattr(equation, "aP"):
aP = equation.aP(trial, self.ai, self.tau)
else:
aP = None

# Boundary conditions (assumes extruded mesh)
# BCs are declared for the plain velocity space. As we need them in
# extended mixed problem, we replicate the BCs but for subspace of W
bcs = []
for bc in equation.bcs['u']:
bcs.append(DirichletBC(W.sub(0), bc.function_arg, bc.sub_domain))
bcs.append(DirichletBC(W.sub(1), bc.function_arg, bc.sub_domain))

rexi_prob = LinearVariationalProblem(a.form, L.form, self.w, aP=aP,
bcs=bcs,
constant_jacobian=False)

# if solver_parameters is None:
# solver_parameters = equation.solver_parameters

self.solver = LinearVariationalSolver(
rexi_prob, solver_parameters=solver_parameters)

def solve(self, x_out, x_in, dt):
"""
Solve method for approximating the matrix exponential by a
rational sum. Solves

(A_n + tau L)V_n = U

multiplies by the corresponding B_n and sums over n.

:arg U0: the mixed function on the rhs.
:arg dt: the value of tau

"""

# assign tau and U0 and initialise solution to 0.
self.tau.assign(dt)
Uin = x_in.subfunctions
U0 = self.U0.subfunctions
for i in range(len(Uin)):
U0[2*i].assign(Uin[i])
self.w_.assign(0.)
w_ = self.w_.subfunctions
w = self.w.subfunctions

# loop over solvers, assigning a_i, solving and accumulating the sum
for i in range(self.N):
j = self.idx + i
self.ar.assign(self.alpha[j].real)
self.ai.assign(self.alpha[j].imag)
self.solver.solve()
for k in range(len(Uin)):
wk = w_[2*k]
wk += Constant(self.beta[j].real)*w[2*k] - Constant(self.beta[j].imag)*w[2*k+1]

# in parallel we have to accumulate the sum over all processes
if self.manager is not None:
self.manager.allreduce(self.w_, self.w_sum)
else:
self.w_sum.assign(self.w_)

w_sum = self.w_sum.subfunctions
w_out = self.w_out.subfunctions
for i in range(len(w_out)):
w_out[i].assign(w_sum[2*i])

x_out.assign(self.w_out)
Loading
Loading